From ab82302c9590783d685c467ad2d1967c2fbbee26 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 24 Sep 2024 12:29:15 +0200 Subject: [PATCH 01/81] [client] Remove usage of custom dialer for localhost (#2639) * Downgrade error log level for network monitor warnings * Do not use custom dialer for localhost --- client/internal/networkmonitor/monitor_bsd.go | 10 +++++----- client/internal/wgproxy/proxy_userspace.go | 5 ++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go index 51135a729e7..4dc2c1aa304 100644 --- a/client/internal/networkmonitor/monitor_bsd.go +++ b/client/internal/networkmonitor/monitor_bsd.go @@ -24,7 +24,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca defer func() { err := unix.Close(fd) if err != nil && !errors.Is(err, unix.EBADF) { - log.Errorf("Network monitor: failed to close routing socket: %v", err) + log.Warnf("Network monitor: failed to close routing socket: %v", err) } }() @@ -32,7 +32,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca <-ctx.Done() err := unix.Close(fd) if err != nil && !errors.Is(err, unix.EBADF) { - log.Debugf("Network monitor: closed routing socket") + log.Debugf("Network monitor: closed routing socket: %v", err) } }() @@ -45,12 +45,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca n, err := unix.Read(fd, buf) if err != nil { if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { - log.Errorf("Network monitor: failed to read from routing socket: %v", err) + log.Warnf("Network monitor: failed to read from routing socket: %v", err) } continue } if n < unix.SizeofRtMsghdr { - log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n) + log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) continue } @@ -61,7 +61,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca case unix.RTM_ADD, syscall.RTM_DELETE: route, err := parseRouteMessage(buf[:n]) if err != nil { - log.Errorf("Network monitor: error parsing routing message: %v", err) + log.Debugf("Network monitor: error parsing routing message: %v", err) continue } diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go index 701f615b9f5..8fc640b6ad1 100644 --- a/client/internal/wgproxy/proxy_userspace.go +++ b/client/internal/wgproxy/proxy_userspace.go @@ -7,8 +7,6 @@ import ( "net" log "github.com/sirupsen/logrus" - - nbnet "github.com/netbirdio/netbird/util/net" ) // WGUserSpaceProxy proxies @@ -36,7 +34,8 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) { p.remoteConn = remoteConn var err error - p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) + dialer := &net.Dialer{} + p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) return nil, err From e7d52c8c95aa0a0520442bc6c984be2343d70ee8 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 24 Sep 2024 20:57:56 +0200 Subject: [PATCH 02/81] [client] Fix error count formatting (#2641) --- client/errors/errors.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/client/errors/errors.go b/client/errors/errors.go index cef999ac872..8faadbda5ff 100644 --- a/client/errors/errors.go +++ b/client/errors/errors.go @@ -8,8 +8,8 @@ import ( ) func formatError(es []error) string { - if len(es) == 0 { - return fmt.Sprintf("0 error occurred:\n\t* %s", es[0]) + if len(es) == 1 { + return fmt.Sprintf("1 error occurred:\n\t* %s", es[0]) } points := make([]string, len(es)) From b51d75204b13a191ecb2ab01ad81d1a73b49b5c5 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 24 Sep 2024 20:58:18 +0200 Subject: [PATCH 03/81] [client] Anonymize relay address in status peers view (#2640) --- client/cmd/status.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client/cmd/status.go b/client/cmd/status.go index 1ef8b49138b..ed3daa2b5fd 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -805,6 +805,9 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) { if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil { peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port) } + + peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress) + for i, route := range peer.Routes { peer.Routes[i] = a.AnonymizeIPString(route) } From 1e4a0f77e27710e57c66ef775f9ccd1e97e82a84 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 25 Sep 2024 18:22:27 +0200 Subject: [PATCH 04/81] Add get DB method to store (#2650) --- management/server/sql_store.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 6f1f66ef81b..8fa5f9d0588 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1024,3 +1024,7 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store { db: tx, } } + +func (s *SqlStore) GetDB() *gorm.DB { + return s.db +} From 4ebf6e1c4c5b549ad6983b2a7a36874fd8a85dc4 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 25 Sep 2024 18:50:10 +0200 Subject: [PATCH 05/81] [client] Close the remote conn in proxy (#2626) Port the conn close call to eBPF proxy --- client/internal/engine.go | 2 +- client/internal/peer/conn.go | 8 +- client/internal/peer/conn_test.go | 8 +- .../internal/wgproxy/{ => ebpf}/portlookup.go | 2 +- .../wgproxy/{ => ebpf}/portlookup_test.go | 2 +- .../wgproxy/{proxy_ebpf.go => ebpf/proxy.go} | 165 ++++++++++-------- .../proxy_test.go} | 9 +- client/internal/wgproxy/ebpf/wrapper.go | 44 +++++ client/internal/wgproxy/factory.go | 22 --- client/internal/wgproxy/factory_linux.go | 33 +++- client/internal/wgproxy/factory_nonlinux.go | 16 +- client/internal/wgproxy/proxy.go | 6 +- client/internal/wgproxy/proxy_test.go | 128 ++++++++++++++ client/internal/wgproxy/proxy_userspace.go | 129 -------------- client/internal/wgproxy/usp/proxy.go | 146 ++++++++++++++++ relay/client/picker_test.go | 2 +- 16 files changed, 467 insertions(+), 255 deletions(-) rename client/internal/wgproxy/{ => ebpf}/portlookup.go (96%) rename client/internal/wgproxy/{ => ebpf}/portlookup_test.go (97%) rename client/internal/wgproxy/{proxy_ebpf.go => ebpf/proxy.go} (65%) rename client/internal/wgproxy/{proxy_ebpf_test.go => ebpf/proxy_test.go} (86%) create mode 100644 client/internal/wgproxy/ebpf/wrapper.go delete mode 100644 client/internal/wgproxy/factory.go create mode 100644 client/internal/wgproxy/proxy_test.go delete mode 100644 client/internal/wgproxy/proxy_userspace.go create mode 100644 
client/internal/wgproxy/usp/proxy.go diff --git a/client/internal/engine.go b/client/internal/engine.go index b0deb5a29a2..463507ad89a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -292,7 +292,7 @@ func (e *Engine) Start() error { e.wgInterface = wgIface userspace := e.wgInterface.IsUserspaceBind() - e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort) + e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort) if e.config.RosenpassEnabled { log.Infof("rosenpass is enabled") diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 911ddd2281c..ea6d892b9f6 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -527,8 +527,8 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { conn.log.Debugf("Relay connection is ready to use") conn.statusRelay.Set(StatusConnected) - wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx) - endpoint, err := wgProxy.AddTurnConn(rci.relayedConn) + wgProxy := conn.wgProxyFactory.GetProxy() + endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn) if err != nil { conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) return @@ -775,8 +775,8 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil } conn.log.Debugf("setup ice turn connection") - wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx) - ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn) + wgProxy := conn.wgProxyFactory.GetProxy() + ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) if err != nil { conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) if errClose := wgProxy.CloseConn(); errClose != nil { diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 80c25f63c0e..22e5409f894 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -44,7 +44,7 @@ func TestNewConn_interfaceFilter(t *testing.T) { } func TestConn_GetKey(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() @@ -59,7 +59,7 @@ func TestConn_GetKey(t *testing.T) { } func TestConn_OnRemoteOffer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() @@ -96,7 +96,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { } func TestConn_OnRemoteAnswer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() @@ -132,7 +132,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { wg.Wait() } func TestConn_Status(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() diff --git a/client/internal/wgproxy/portlookup.go b/client/internal/wgproxy/ebpf/portlookup.go similarity index 96% rename from client/internal/wgproxy/portlookup.go rename to client/internal/wgproxy/ebpf/portlookup.go index 6f3d33487ea..0e2c20c9911 100644 --- a/client/internal/wgproxy/portlookup.go +++ 
b/client/internal/wgproxy/ebpf/portlookup.go @@ -1,4 +1,4 @@ -package wgproxy +package ebpf import ( "fmt" diff --git a/client/internal/wgproxy/portlookup_test.go b/client/internal/wgproxy/ebpf/portlookup_test.go similarity index 97% rename from client/internal/wgproxy/portlookup_test.go rename to client/internal/wgproxy/ebpf/portlookup_test.go index 6a386f33087..92f4b8eee9f 100644 --- a/client/internal/wgproxy/portlookup_test.go +++ b/client/internal/wgproxy/ebpf/portlookup_test.go @@ -1,4 +1,4 @@ -package wgproxy +package ebpf import ( "fmt" diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/ebpf/proxy.go similarity index 65% rename from client/internal/wgproxy/proxy_ebpf.go rename to client/internal/wgproxy/ebpf/proxy.go index d385cc4caca..4bd4bfff624 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -1,6 +1,6 @@ //go:build linux && !android -package wgproxy +package ebpf import ( "context" @@ -13,47 +13,49 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/hashicorp/go-multierror" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" nbnet "github.com/netbirdio/netbird/util/net" ) +const ( + loopbackAddr = "127.0.0.1" +) + // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { - ebpfManager ebpfMgr.Manager - - ctx context.Context - cancel context.CancelFunc - - lastUsedPort uint16 localWGListenPort int + ebpfManager ebpfMgr.Manager turnConnStore map[uint16]net.Conn turnConnMutex sync.Mutex - rawConn net.PacketConn - conn transport.UDPConn + lastUsedPort uint16 + rawConn net.PacketConn + conn transport.UDPConn + + ctx context.Context + ctxCancel context.CancelFunc } // NewWGEBPFProxy create new WGEBPFProxy instance -func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy { +func NewWGEBPFProxy(wgPort int) *WGEBPFProxy { log.Debugf("instantiate ebpf proxy") wgProxy := &WGEBPFProxy{ localWGListenPort: wgPort, ebpfManager: ebpf.GetEbpfManagerInstance(), - lastUsedPort: 0, turnConnStore: make(map[uint16]net.Conn), } - wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx) - return wgProxy } -// listen load ebpf program and listen the proxy -func (p *WGEBPFProxy) listen() error { +// Listen load ebpf program and listen the proxy +func (p *WGEBPFProxy) Listen() error { pl := portLookup{} wgPorxyPort, err := pl.searchFreePort() if err != nil { @@ -72,9 +74,11 @@ func (p *WGEBPFProxy) listen() error { addr := net.UDPAddr{ Port: wgPorxyPort, - IP: net.ParseIP("127.0.0.1"), + IP: net.ParseIP(loopbackAddr), } + p.ctx, p.ctxCancel = context.WithCancel(context.Background()) + conn, err := nbnet.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() @@ -91,106 +95,110 @@ func (p *WGEBPFProxy) listen() error { } // AddTurnConn add new turn connection for the proxy -func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) { +func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { wgEndpointPort, err := p.storeTurnConn(turnConn) if err != nil { return nil, err } - go p.proxyToLocal(wgEndpointPort, turnConn) + go p.proxyToLocal(ctx, wgEndpointPort, turnConn) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) wgEndpoint := &net.UDPAddr{ - IP: 
net.ParseIP("127.0.0.1"), + IP: net.ParseIP(loopbackAddr), Port: int(wgEndpointPort), } return wgEndpoint, nil } -// CloseConn doing nothing because this type of proxy implementation does not store the connection -func (p *WGEBPFProxy) CloseConn() error { - return nil -} - -// Free resources +// Free resources except the remoteConns will be keep open. func (p *WGEBPFProxy) Free() error { log.Debugf("free up ebpf wg proxy") - var err1, err2, err3 error - if p.conn != nil { - err1 = p.conn.Close() + if p.ctx != nil && p.ctx.Err() != nil { + //nolint + return nil } - err2 = p.ebpfManager.FreeWGProxy() - if p.rawConn != nil { - err3 = p.rawConn.Close() - } + p.ctxCancel() - if err1 != nil { - return err1 + var result *multierror.Error + if err := p.conn.Close(); err != nil { + result = multierror.Append(result, err) } - if err2 != nil { - return err2 + if err := p.ebpfManager.FreeWGProxy(); err != nil { + result = multierror.Append(result, err) } - return err3 + if err := p.rawConn.Close(); err != nil { + result = multierror.Append(result, err) + } + return nberrors.FormatErrorOrNil(result) } -func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) { +func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) { + defer p.removeTurnConn(endpointPort) + + var ( + err error + n int + ) buf := make([]byte, 1500) - var err error - defer func() { - p.removeTurnConn(endpointPort) - }() - for { - select { - case <-p.ctx.Done(): - return - default: - var n int - n, err = remoteConn.Read(buf) - if err != nil { - if err != io.EOF { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) - } + for ctx.Err() == nil { + n, err = remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { return } - err = p.sendPkg(buf[:n], endpointPort) - if err != nil { - log.Errorf("failed to write out turn pkg to local conn: %v", err) + if err != io.EOF { + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) + } + return + } + + if err := p.sendPkg(buf[:n], endpointPort); err != nil { + if ctx.Err() != nil || p.ctx.Err() != nil { + return } + log.Errorf("failed to write out turn pkg to local conn: %v", err) } } } // proxyToRemote read messages from local WireGuard interface and forward it to remote conn +// From this go routine has only one instance. 
func (p *WGEBPFProxy) proxyToRemote() { buf := make([]byte, 1500) - for { - select { - case <-p.ctx.Done(): - return - default: - n, addr, err := p.conn.ReadFromUDP(buf) - if err != nil { - log.Errorf("failed to read UDP pkg from WG: %s", err) + for p.ctx.Err() == nil { + if err := p.readAndForwardPacket(buf); err != nil { + if p.ctx.Err() != nil { return } + log.Errorf("failed to proxy packet to remote conn: %s", err) + } + } +} - p.turnConnMutex.Lock() - conn, ok := p.turnConnStore[uint16(addr.Port)] - p.turnConnMutex.Unlock() - if !ok { - log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port) - continue - } +func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error { + n, addr, err := p.conn.ReadFromUDP(buf) + if err != nil { + return fmt.Errorf("failed to read UDP packet from WG: %w", err) + } - _, err = conn.Write(buf[:n]) - if err != nil { - log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err) - } + p.turnConnMutex.Lock() + conn, ok := p.turnConnStore[uint16(addr.Port)] + p.turnConnMutex.Unlock() + if !ok { + if p.ctx.Err() == nil { + log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port) } + return nil + } + + if _, err := conn.Write(buf[:n]); err != nil { + return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err) } + return nil } func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) { @@ -206,11 +214,14 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) { } func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) { - log.Debugf("remove turn conn from store by port: %d", turnConnID) p.turnConnMutex.Lock() defer p.turnConnMutex.Unlock() - delete(p.turnConnStore, turnConnID) + _, ok := p.turnConnStore[turnConnID] + if ok { + log.Debugf("remove turn conn from store by port: %d", turnConnID) + } + delete(p.turnConnStore, turnConnID) } func (p *WGEBPFProxy) nextFreePort() (uint16, error) { diff --git a/client/internal/wgproxy/proxy_ebpf_test.go b/client/internal/wgproxy/ebpf/proxy_test.go similarity index 86% rename from client/internal/wgproxy/proxy_ebpf_test.go rename to client/internal/wgproxy/ebpf/proxy_test.go index 821e64218db..b15bc686c0d 100644 --- a/client/internal/wgproxy/proxy_ebpf_test.go +++ b/client/internal/wgproxy/ebpf/proxy_test.go @@ -1,14 +1,13 @@ //go:build linux && !android -package wgproxy +package ebpf import ( - "context" "testing" ) func TestWGEBPFProxy_connStore(t *testing.T) { - wgProxy := NewWGEBPFProxy(context.Background(), 1) + wgProxy := NewWGEBPFProxy(1) p, _ := wgProxy.storeTurnConn(nil) if p != 1 { @@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) { } func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { - wgProxy := NewWGEBPFProxy(context.Background(), 1) + wgProxy := NewWGEBPFProxy(1) _, _ = wgProxy.storeTurnConn(nil) wgProxy.lastUsedPort = 65535 @@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { } func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) { - wgProxy := NewWGEBPFProxy(context.Background(), 1) + wgProxy := NewWGEBPFProxy(1) for i := 0; i < 65535; i++ { _, _ = wgProxy.storeTurnConn(nil) diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go new file mode 100644 index 00000000000..c5639f840cc --- /dev/null +++ b/client/internal/wgproxy/ebpf/wrapper.go @@ -0,0 +1,44 @@ +//go:build linux && !android + +package ebpf + +import ( + "context" + "fmt" + "net" 
+) + +// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call +type ProxyWrapper struct { + WgeBPFProxy *WGEBPFProxy + + remoteConn net.Conn + cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread +} + +func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { + ctxConn, cancel := context.WithCancel(ctx) + addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn) + + if err != nil { + cancel() + return nil, fmt.Errorf("add turn conn: %w", err) + } + e.remoteConn = remoteConn + e.cancel = cancel + return addr, err +} + +// CloseConn close the remoteConn and automatically remove the conn instance from the map +func (e *ProxyWrapper) CloseConn() error { + if e.cancel == nil { + return fmt.Errorf("proxy not started") + } + + e.cancel() + + if err := e.remoteConn.Close(); err != nil { + return fmt.Errorf("failed to close remote conn: %w", err) + } + return nil +} diff --git a/client/internal/wgproxy/factory.go b/client/internal/wgproxy/factory.go deleted file mode 100644 index f4eb150b073..00000000000 --- a/client/internal/wgproxy/factory.go +++ /dev/null @@ -1,22 +0,0 @@ -package wgproxy - -import "context" - -type Factory struct { - wgPort int - ebpfProxy Proxy -} - -func (w *Factory) GetProxy(ctx context.Context) Proxy { - if w.ebpfProxy != nil { - return w.ebpfProxy - } - return NewWGUserSpaceProxy(ctx, w.wgPort) -} - -func (w *Factory) Free() error { - if w.ebpfProxy != nil { - return w.ebpfProxy.Free() - } - return nil -} diff --git a/client/internal/wgproxy/factory_linux.go b/client/internal/wgproxy/factory_linux.go index d01ae7e742c..369ba99db1f 100644 --- a/client/internal/wgproxy/factory_linux.go +++ b/client/internal/wgproxy/factory_linux.go @@ -3,20 +3,26 @@ package wgproxy import ( - "context" - log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/internal/wgproxy/usp" ) -func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory { +type Factory struct { + wgPort int + ebpfProxy *ebpf.WGEBPFProxy +} + +func NewFactory(userspace bool, wgPort int) *Factory { f := &Factory{wgPort: wgPort} if userspace { return f } - ebpfProxy := NewWGEBPFProxy(ctx, wgPort) - err := ebpfProxy.listen() + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) + err := ebpfProxy.Listen() if err != nil { log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) return f @@ -25,3 +31,20 @@ func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory { f.ebpfProxy = ebpfProxy return f } + +func (w *Factory) GetProxy() Proxy { + if w.ebpfProxy != nil { + p := &ebpf.ProxyWrapper{ + WgeBPFProxy: w.ebpfProxy, + } + return p + } + return usp.NewWGUserSpaceProxy(w.wgPort) +} + +func (w *Factory) Free() error { + if w.ebpfProxy == nil { + return nil + } + return w.ebpfProxy.Free() +} diff --git a/client/internal/wgproxy/factory_nonlinux.go b/client/internal/wgproxy/factory_nonlinux.go index d1640c97dd0..f930b09b3a0 100644 --- a/client/internal/wgproxy/factory_nonlinux.go +++ b/client/internal/wgproxy/factory_nonlinux.go @@ -2,8 +2,20 @@ package wgproxy -import "context" +import "github.com/netbirdio/netbird/client/internal/wgproxy/usp" -func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory { +type Factory struct { + wgPort int +} + +func NewFactory(_ bool, wgPort int) *Factory { return &Factory{wgPort: wgPort} } + +func (w *Factory) GetProxy() Proxy { + return 
usp.NewWGUserSpaceProxy(w.wgPort) +} + +func (w *Factory) Free() error { + return nil +} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go index b88df73a092..96fae8dd103 100644 --- a/client/internal/wgproxy/proxy.go +++ b/client/internal/wgproxy/proxy.go @@ -1,12 +1,12 @@ package wgproxy import ( + "context" "net" ) -// Proxy is a transfer layer between the Turn connection and the WireGuard +// Proxy is a transfer layer between the relayed connection and the WireGuard type Proxy interface { - AddTurnConn(turnConn net.Conn) (net.Addr, error) + AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) CloseConn() error - Free() error } diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go new file mode 100644 index 00000000000..b09e6be555f --- /dev/null +++ b/client/internal/wgproxy/proxy_test.go @@ -0,0 +1,128 @@ +//go:build linux + +package wgproxy + +import ( + "context" + "io" + "net" + "os" + "runtime" + "testing" + "time" + + "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/internal/wgproxy/usp" + "github.com/netbirdio/netbird/util" +) + +func TestMain(m *testing.M) { + _ = util.InitLog("trace", "console") + code := m.Run() + os.Exit(code) +} + +type mocConn struct { + closeChan chan struct{} + closed bool +} + +func newMockConn() *mocConn { + return &mocConn{ + closeChan: make(chan struct{}), + } +} + +func (m *mocConn) Read(b []byte) (n int, err error) { + <-m.closeChan + return 0, io.EOF +} + +func (m *mocConn) Write(b []byte) (n int, err error) { + <-m.closeChan + return 0, io.EOF +} + +func (m *mocConn) Close() error { + if m.closed == true { + return nil + } + + m.closed = true + close(m.closeChan) + return nil +} + +func (m *mocConn) LocalAddr() net.Addr { + panic("implement me") +} + +func (m *mocConn) RemoteAddr() net.Addr { + return &net.UDPAddr{ + IP: net.ParseIP("172.16.254.1"), + } +} + +func (m *mocConn) SetDeadline(t time.Time) error { + panic("implement me") +} + +func (m *mocConn) SetReadDeadline(t time.Time) error { + panic("implement me") +} + +func (m *mocConn) SetWriteDeadline(t time.Time) error { + panic("implement me") +} + +func TestProxyCloseByRemoteConn(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + proxy Proxy + }{ + { + name: "userspace proxy", + proxy: usp.NewWGUserSpaceProxy(51830), + }, + } + + if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { + ebpfProxy := ebpf.NewWGEBPFProxy(51831) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %s", err) + } + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %s", err) + } + }() + proxyWrapper := &ebpf.ProxyWrapper{ + WgeBPFProxy: ebpfProxy, + } + + tests = append(tests, struct { + name string + proxy Proxy + }{ + name: "ebpf proxy", + proxy: proxyWrapper, + }) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + relayedConn := newMockConn() + _, err := tt.proxy.AddTurnConn(ctx, relayedConn) + if err != nil { + t.Errorf("error: %v", err) + } + + _ = relayedConn.Close() + if err := tt.proxy.CloseConn(); err != nil { + t.Errorf("error: %v", err) + } + }) + } +} diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go deleted file mode 100644 index 8fc640b6ad1..00000000000 --- a/client/internal/wgproxy/proxy_userspace.go +++ /dev/null @@ -1,129 +0,0 @@ -package wgproxy - -import ( - 
"context" - "fmt" - "io" - "net" - - log "github.com/sirupsen/logrus" -) - -// WGUserSpaceProxy proxies -type WGUserSpaceProxy struct { - localWGListenPort int - ctx context.Context - cancel context.CancelFunc - - remoteConn net.Conn - localConn net.Conn -} - -// NewWGUserSpaceProxy instantiate a user space WireGuard proxy -func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy { - log.Debugf("Initializing new user space proxy with port %d", wgPort) - p := &WGUserSpaceProxy{ - localWGListenPort: wgPort, - } - p.ctx, p.cancel = context.WithCancel(ctx) - return p -} - -// AddTurnConn start the proxy with the given remote conn -func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) { - p.remoteConn = remoteConn - - var err error - dialer := &net.Dialer{} - p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) - if err != nil { - log.Errorf("failed dialing to local Wireguard port %s", err) - return nil, err - } - - go p.proxyToRemote() - go p.proxyToLocal() - - return p.localConn.LocalAddr(), err -} - -// CloseConn close the localConn -func (p *WGUserSpaceProxy) CloseConn() error { - p.cancel() - if p.localConn == nil { - return nil - } - - if p.remoteConn == nil { - return nil - } - - if err := p.remoteConn.Close(); err != nil { - log.Warnf("failed to close remote conn: %s", err) - } - return p.localConn.Close() -} - -// Free doing nothing because this implementation of proxy does not have global state -func (p *WGUserSpaceProxy) Free() error { - return nil -} - -// proxyToRemote proxies everything from Wireguard to the RemoteKey peer -// blocks -func (p *WGUserSpaceProxy) proxyToRemote() { - defer log.Infof("exit from proxyToRemote: %s", p.localConn.LocalAddr()) - - buf := make([]byte, 1500) - for { - select { - case <-p.ctx.Done(): - return - default: - n, err := p.localConn.Read(buf) - if err != nil { - log.Debugf("failed to read from wg interface conn: %s", err) - continue - } - - _, err = p.remoteConn.Write(buf[:n]) - if err != nil { - if err == io.EOF { - p.cancel() - } else { - log.Debugf("failed to write to remote conn: %s", err) - } - continue - } - } - } -} - -// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard -// blocks -func (p *WGUserSpaceProxy) proxyToLocal() { - defer p.cancel() - defer log.Infof("exit from proxyToLocal: %s", p.localConn.LocalAddr()) - buf := make([]byte, 1500) - for { - select { - case <-p.ctx.Done(): - return - default: - n, err := p.remoteConn.Read(buf) - if err != nil { - if err == io.EOF { - return - } - log.Errorf("failed to read from remote conn: %s", err) - continue - } - - _, err = p.localConn.Write(buf[:n]) - if err != nil { - log.Debugf("failed to write to wg interface conn: %s", err) - continue - } - } - } -} diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go new file mode 100644 index 00000000000..83a8725d899 --- /dev/null +++ b/client/internal/wgproxy/usp/proxy.go @@ -0,0 +1,146 @@ +package usp + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/errors" +) + +// WGUserSpaceProxy proxies +type WGUserSpaceProxy struct { + localWGListenPort int + ctx context.Context + cancel context.CancelFunc + + remoteConn net.Conn + localConn net.Conn + closeMu sync.Mutex + closed bool +} + +// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. 
This is not a thread safe implementation +func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { + log.Debugf("Initializing new user space proxy with port %d", wgPort) + p := &WGUserSpaceProxy{ + localWGListenPort: wgPort, + } + return p +} + +// AddTurnConn start the proxy with the given remote conn +func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { + p.ctx, p.cancel = context.WithCancel(ctx) + + p.remoteConn = remoteConn + + var err error + dialer := net.Dialer{} + p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) + if err != nil { + log.Errorf("failed dialing to local Wireguard port %s", err) + return nil, err + } + + go p.proxyToRemote() + go p.proxyToLocal() + + return p.localConn.LocalAddr(), err +} + +// CloseConn close the localConn +func (p *WGUserSpaceProxy) CloseConn() error { + if p.cancel == nil { + return fmt.Errorf("proxy not started") + } + return p.close() +} + +func (p *WGUserSpaceProxy) close() error { + p.closeMu.Lock() + defer p.closeMu.Unlock() + + // prevent double close + if p.closed { + return nil + } + p.closed = true + + p.cancel() + + var result *multierror.Error + if err := p.remoteConn.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) + } + + if err := p.localConn.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) + } + return errors.FormatErrorOrNil(result) +} + +// proxyToRemote proxies from Wireguard to the RemoteKey +func (p *WGUserSpaceProxy) proxyToRemote() { + defer func() { + if err := p.close(); err != nil { + log.Warnf("error in proxy to remote loop: %s", err) + } + }() + + buf := make([]byte, 1500) + for p.ctx.Err() == nil { + n, err := p.localConn.Read(buf) + if err != nil { + if p.ctx.Err() != nil { + return + } + log.Debugf("failed to read from wg interface conn: %s", err) + return + } + + _, err = p.remoteConn.Write(buf[:n]) + if err != nil { + if p.ctx.Err() != nil { + return + } + + log.Debugf("failed to write to remote conn: %s", err) + return + } + } +} + +// proxyToLocal proxies from the Remote peer to local WireGuard +func (p *WGUserSpaceProxy) proxyToLocal() { + defer func() { + if err := p.close(); err != nil { + log.Warnf("error in proxy to local loop: %s", err) + } + }() + + buf := make([]byte, 1500) + for p.ctx.Err() == nil { + n, err := p.remoteConn.Read(buf) + if err != nil { + if p.ctx.Err() != nil { + return + } + log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) + return + } + + _, err = p.localConn.Write(buf[:n]) + if err != nil { + if p.ctx.Err() != nil { + return + } + log.Debugf("failed to write to wg interface conn: %s", err) + continue + } + } +} diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go index f5649d700a4..eb14581e067 100644 --- a/relay/client/picker_test.go +++ b/relay/client/picker_test.go @@ -13,7 +13,7 @@ func TestServerPicker_UnavailableServers(t *testing.T) { PeerID: "test", } - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() go func() { From acb73bd64abeca70d1ae39d0c088fbf2ffcb80aa Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 27 Sep 2024 17:10:50 +0300 Subject: [PATCH 06/81] [management] Remove redundant get account calls in GetAccountFromToken (#2615) * refactor access control middleware and user access by JWT groups Signed-off-by: bcmmbaga * 
refactor jwt groups extractor Signed-off-by: bcmmbaga * refactor handlers to get account when necessary Signed-off-by: bcmmbaga * refactor getAccountFromToken Signed-off-by: bcmmbaga * refactor getAccountWithAuthorizationClaims Signed-off-by: bcmmbaga * fix merge Signed-off-by: bcmmbaga * revert handles change Signed-off-by: bcmmbaga * remove GetUserByID from account manager Signed-off-by: bcmmbaga * fix tests Signed-off-by: bcmmbaga * refactor getAccountWithAuthorizationClaims to return account id Signed-off-by: bcmmbaga * refactor handlers to use GetAccountIDFromToken Signed-off-by: bcmmbaga * fix tests Signed-off-by: bcmmbaga * remove locks Signed-off-by: bcmmbaga * refactor Signed-off-by: bcmmbaga * add GetGroupByName from store Signed-off-by: bcmmbaga * add GetGroupByID from store and refactor Signed-off-by: bcmmbaga * Refactor retrieval of policy and posture checks Signed-off-by: bcmmbaga * Refactor user permissions and retrieves PAT Signed-off-by: bcmmbaga * Refactor route, setupkey, nameserver and dns to get record(s) from store Signed-off-by: bcmmbaga * Refactor store Signed-off-by: bcmmbaga * fix lint Signed-off-by: bcmmbaga * fix tests Signed-off-by: bcmmbaga * fix add missing policy source posture checks Signed-off-by: bcmmbaga * add store lock Signed-off-by: bcmmbaga * fix tests Signed-off-by: bcmmbaga * add get account Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 383 +++++++++++------- management/server/account_test.go | 107 +++-- management/server/dns.go | 16 +- management/server/file_store.go | 101 ++++- management/server/group.go | 94 ++--- management/server/grpcserver.go | 2 +- management/server/http/accounts_handler.go | 46 +-- .../server/http/accounts_handler_test.go | 7 +- .../server/http/dns_settings_handler.go | 8 +- .../server/http/dns_settings_handler_test.go | 4 +- management/server/http/events_handler.go | 6 +- management/server/http/events_handler_test.go | 14 +- .../server/http/geolocation_handler_test.go | 15 +- .../server/http/geolocations_handler.go | 7 +- management/server/http/groups_handler.go | 111 ++--- management/server/http/groups_handler_test.go | 63 ++- management/server/http/nameservers_handler.go | 20 +- .../server/http/nameservers_handler_test.go | 13 +- management/server/http/pat_handler.go | 18 +- management/server/http/pat_handler_test.go | 6 +- management/server/http/peers_handler.go | 52 ++- management/server/http/peers_handler_test.go | 28 +- management/server/http/policies_handler.go | 152 ++++--- .../server/http/policies_handler_test.go | 16 +- .../server/http/posture_checks_handler.go | 46 +-- .../http/posture_checks_handler_test.go | 12 +- management/server/http/routes_handler.go | 40 +- management/server/http/routes_handler_test.go | 16 +- management/server/http/setupkeys_handler.go | 18 +- .../server/http/setupkeys_handler_test.go | 18 +- management/server/http/users_handler.go | 32 +- management/server/http/users_handler_test.go | 7 +- management/server/mock_server/account_mock.go | 70 +++- management/server/nameserver.go | 42 +- management/server/peer_test.go | 4 +- management/server/policy.go | 107 +++-- management/server/posture_checks.go | 34 +- management/server/route.go | 52 +-- management/server/route_test.go | 2 +- management/server/setupkey.go | 56 +-- management/server/sql_store.go | 181 ++++++++- management/server/store.go | 73 +++- management/server/user.go | 86 ++-- management/server/user_test.go | 11 +- 44 files changed, 1247 insertions(+), 949 deletions(-) diff --git 
a/management/server/account.go b/management/server/account.go index 208315643a6..710b6f62f35 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -20,11 +20,6 @@ import ( cacheStore "github.com/eko/gocache/v3/store" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" - gocache "github.com/patrickmn/go-cache" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -41,6 +36,10 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" + gocache "github.com/patrickmn/go-cache" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" ) const ( @@ -63,6 +62,7 @@ func cacheEntryExpiration() time.Duration { type AccountManager interface { GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) + GetAccount(ctx context.Context, accountID string) (*Account, error) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) @@ -75,12 +75,14 @@ type AccountManager interface { SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) - GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) - GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) + GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) + GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) + GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error + GetUserByID(ctx context.Context, id string) (*User, error) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -107,7 +109,7 @@ type AccountManager interface { GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) 
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) @@ -145,6 +147,7 @@ type AccountManager interface { SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) + GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) } type DefaultAccountManager struct { @@ -268,6 +271,11 @@ type AccountNetwork struct { Network *Network `gorm:"embedded;embeddedPrefix:network_"` } +// AccountDNSSettings used in gorm to only load dns settings and not whole account +type AccountDNSSettings struct { + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` +} + type UserPermissions struct { DashboardView string `json:"dashboard_view"` } @@ -1252,25 +1260,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } -// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and -// userID doesn't have an account associated with it, one account is created -// domain is used to create a new account if no account is found -func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) { +// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. +// If an accountID is provided, it checks if the account exists and returns it. +// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. +// If the user doesn't have an account, it creates one using the provided domain. +// Returns the account ID or an error if none is found or created. 
+func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { if accountID != "" { - return am.Store.GetAccount(ctx, accountID) - } else if userID != "" { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) if err != nil { - return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID) + return "", err } - err = am.addAccountIDToIDPAppMeta(ctx, userID, account) + if !exists { + return "", status.Errorf(status.NotFound, "account %s does not exist", accountID) + } + return accountID, nil + } + + if userID != "" { + account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) if err != nil { - return nil, err + return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) + } + + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + return "", err } - return account, nil + + return account.Id, nil } - return nil, status.Errorf(status.NotFound, "no valid user or account Id provided") + return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") } func isNil(i idp.Manager) bool { @@ -1613,13 +1633,18 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai } // redeemInvite checks whether user has been invited and redeems the invite -func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error { +func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error { // only possible with the enabled IdP manager if am.idpManager == nil { log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") return nil } + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return err + } + user, err := am.lookupUserInCache(ctx, userID, account) if err != nil { return err @@ -1678,6 +1703,11 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string return am.Store.SaveAccount(ctx, account) } +// GetAccount returns an account associated with this account ID. +func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) { + return am.Store.GetAccount(ctx, accountID) +} + // GetAccountFromPAT returns Account and User associated with a personal access token func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { if len(token) != PATLength { @@ -1726,10 +1756,24 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st return account, user, pat, nil } -// GetAccountFromToken returns an account associated with this token -func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { +// GetAccountByID returns an account associated with this account ID. 
+func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + } + + return am.Store.GetAccount(ctx, accountID) +} + +// GetAccountIDFromToken returns an account ID associated with this token. +func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { - return nil, nil, fmt.Errorf("user ID is empty") + return "", "", fmt.Errorf("user ID is empty") } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -1739,110 +1783,111 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } - newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims) + accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims) if err != nil { - return nil, nil, err + return "", "", err } - unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id) - alreadyUnlocked := false - defer func() { - if !alreadyUnlocked { - unlock() - } - }() - account, err := am.Store.GetAccount(ctx, newAcc.Id) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { - return nil, nil, err - } - - user := account.Users[claims.UserId] - if user == nil { // this is not really possible because we got an account by user ID - return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) + return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } if !user.IsServiceUser && claims.Invited { - err = am.redeemInvite(ctx, account, claims.UserId) + err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { - return nil, nil, err + return "", "", err } } - if account.Settings.JWTGroupsEnabled { - if account.Settings.JWTGroupsClaimName == "" { - log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") - return account, user, nil + if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { + return "", "", err + } + + return accountID, user.Id, nil +} + +// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, +// and propagates changes to peers if group propagation is enabled. 
+func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + if settings == nil || !settings.JWTGroupsEnabled { + return nil + } + + if settings.JWTGroupsClaimName == "" { + log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") + return nil + } + + // TODO: Remove GetAccount after refactoring account peer's update + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return err + } + + jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) + + oldGroups := make([]string, len(user.AutoGroups)) + copy(oldGroups, user.AutoGroups) + + // Update the account if group membership changes + if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { + addNewGroups := difference(user.AutoGroups, oldGroups) + removeOldGroups := difference(oldGroups, user.AutoGroups) + + if settings.GroupsPropagationEnabled { + account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) + account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) + account.Network.IncSerial() } - if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { - if slice, ok := claim.([]interface{}); ok { - var groupsNames []string - for _, item := range slice { - if g, ok := item.(string); ok { - groupsNames = append(groupsNames, g) - } else { - log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item) - } - } - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) - // if groups were added or modified, save the account - if account.SetJWTGroups(claims.UserId, groupsNames) { - if account.Settings.GroupsPropagationEnabled { - if user, err := account.FindUser(claims.UserId); err == nil { - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := difference(oldGroups, user.AutoGroups) - account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) - account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) 
- account.Network.IncSerial() - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) - } else { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) - unlock() - alreadyUnlocked = true - for _, g := range addNewGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } - } - for _, g := range removeOldGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } - } - } - } - } else { - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) - } - } - } - } else { - log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName) + if err := am.Store.SaveAccount(ctx, account); err != nil { + log.WithContext(ctx).Errorf("failed to save account: %v", err) + return nil + } + + // Propagate changes to peers if group propagation is enabled + if settings.GroupsPropagationEnabled { + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } + + for _, g := range addNewGroups { + if group := account.GetGroup(g); group != nil { + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, + map[string]any{ + "group": group.Name, + "group_id": group.ID, + "is_service_user": user.IsServiceUser, + "user_name": user.ServiceUserName}) + } + } + + for _, g := range removeOldGroups { + if group := account.GetGroup(g); group != nil { + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, + map[string]any{ + "group": group.Name, + "group_id": group.ID, + "is_service_user": user.IsServiceUser, + "user_name": user.ServiceUserName}) } - } else { - log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName) } } - return account, user, nil + return nil } -// getAccountWithAuthorizationClaims retrievs an account using JWT Claims. +// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. // if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // @@ -1859,26 +1904,34 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. 
User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) if claims.UserId == "" { - return nil, fmt.Errorf("user ID is empty") + return "", fmt.Errorf("user ID is empty") } + // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) + return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) } else if claims.AccountId != "" { - accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId) + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err != nil { - return nil, err + return "", err } - if _, ok := accountFromID.Users[claims.UserId]; !ok { - return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + + if userAccountID != claims.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } - if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain { - return accountFromID, nil + + domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + if err != nil { + return "", err + } + + if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain { + return userAccountID, nil } } @@ -1888,48 +1941,53 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) // We checked if the domain has a primary account already - domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) if err != nil { // if NotFound we are good to continue, otherwise return error e, ok := status.FromError(err) if !ok || e.Type() != status.NotFound { - return nil, err + return "", err } } - account, err := am.Store.GetAccountByUser(ctx, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err == nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID) defer unlockAccount() - account, err = am.Store.GetAccountByUser(ctx, claims.UserId) + account, err := am.Store.GetAccountByUser(ctx, claims.UserId) if err != nil { - return nil, err + return "", err } // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, // we compare the account's ID with the domain account ID, and if they don't match, we set the account as // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain // was previously unclassified or classified as public so N users that logged int that time, has they own account // and peers that shouldn't be lost. 
- primaryDomain := domainAccount == nil || account.Id == domainAccount.Id - - err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims) - if err != nil { - return nil, err + primaryDomain := domainAccountID == "" || account.Id == domainAccountID + if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil { + return "", err } - return account, nil + + return account.Id, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - if domainAccount != nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id) + var domainAccount *Account + if domainAccountID != "" { + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) if err != nil { - return nil, err + return "", err } } - return am.handleNewUserAccount(ctx, domainAccount, claims) + + account, err := am.handleNewUserAccount(ctx, domainAccount, claims) + if err != nil { + return "", err + } + return account.Id, nil } else { // other error - return nil, err + return "", err } } @@ -2022,26 +2080,21 @@ func (am *DefaultAccountManager) GetDNSDomain() string { // CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT // group propagation and set the list of groups with access permissions. func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { - account, _, err := am.GetAccountFromToken(ctx, claims) + accountID, _, err := am.GetAccountIDFromToken(ctx, claims) + if err != nil { + return err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err } // Ensures JWT group synchronization to the management is enabled before, // filtering access based on the allowed groups. 
- if account.Settings != nil && account.Settings.JWTGroupsEnabled { - if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 { - userJWTGroups := make([]string, 0) - - if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { - if claimGroups, ok := claim.([]interface{}); ok { - for _, g := range claimGroups { - if group, ok := g.(string); ok { - userJWTGroups = append(userJWTGroups, group) - } - } - } - } + if settings != nil && settings.JWTGroupsEnabled { + if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 { + userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) if !userHasAllowedGroup(allowedGroups, userJWTGroups) { return fmt.Errorf("user does not belong to any of the allowed JWT groups") @@ -2111,6 +2164,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor return newLabel, nil } +func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + } + + return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { @@ -2193,6 +2259,27 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac return acc } +// extractJWTGroups extracts the group names from a JWT token's claims. +func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string { + userJWTGroups := make([]string, 0) + + if claim, ok := claims.Raw[claimName]; ok { + if claimGroups, ok := claim.([]interface{}); ok { + for _, g := range claimGroups { + if group, ok := g.(string); ok { + userJWTGroups = append(userJWTGroups, group) + } else { + log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) + } + } + } + } else { + log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName) + } + + return userJWTGroups +} + // userHasAllowedGroup checks if a user belongs to any of the allowed groups. 
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { for _, userGroup := range userGroups { diff --git a/management/server/account_test.go b/management/server/account_test.go index 03b5fa83efd..303261bead6 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { assert.Equal(t, account.Id, ev.TargetID) } -func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { +func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { type initUserParams jwtclaims.AuthorizationClaims type test struct { @@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get init account failed") + if testCase.inputUpdateAttrs { err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") @@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { testCase.inputClaims.AccountId = initAccount.Id } - account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims) + accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims) require.NoError(t, err, "support function failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) @@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "unable to create account manager") accountID := initAccount.Id - acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain) + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountByUserOrAccountID where the id is getting generated + // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it - initAccount = acc + initAccount, err = manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get init account failed") claims := jwtclaims.AuthorizationClaims{ AccountId: accountID, // is empty as it is based on 
accountID right after initialization of initAccount @@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { } t.Run("JWT groups disabled", func(t *testing.T) { - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 1, "only ALL group should exists") }) @@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") }) @@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 3, "groups should be added to the account") groupsByNames := map[string]*group.Group{} @@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") if err != nil { t.Fatal(err) } - if account == nil { + if accountID == "" { t.Fatalf("expected to create an account for a user %s", userId) return } - _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) + t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) } - _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") if err == nil { t.Errorf("expected an error when user and account IDs are empty") } @@ -1240,7 +1260,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } }() - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1648,19 +1668,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) 
{ manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") - assert.NotNil(t, account.Settings) - assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true) - assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour) + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "unable to get account settings") + + assert.NotNil(t, settings) + assert.Equal(t, settings.PeerLoginExpirationEnabled, true) + assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour) } func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1672,11 +1695,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + + account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1713,7 +1741,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1724,7 +1752,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1741,8 +1769,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. 
}, } - account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "unable to get the account") + // when we mark peer as connected, the peer login expiration routine should trigger err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") @@ -1757,7 +1789,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1769,8 +1801,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") @@ -1813,10 +1849,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -1824,19 +1860,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") require.NoError(t, err, "unable to get account by ID") - assert.False(t, account.Settings.PeerLoginExpirationEnabled) - assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour) + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "unable to get account settings") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + assert.False(t, settings.PeerLoginExpirationEnabled) + assert.Equal(t, settings.PeerLoginExpiration, time.Hour) + + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: 
time.Second, PeerLoginExpirationEnabled: false, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpirationEnabled: false, }) diff --git a/management/server/dns.go b/management/server/dns.go index 1d156c90a62..7410aaa15cc 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings { // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - user, err := account.FindUser(userID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") } - dnsSettings := account.DNSSettings.Copy() - return &dnsSettings, nil + + return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) } // SaveDNSSettings validates a user role and updates the account's DNS settings diff --git a/management/server/file_store.go b/management/server/file_store.go index 95d5b4e6e46..994a4b1eec5 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -10,14 +10,15 @@ import ( "sync" "time" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - + "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/util" ) @@ -634,10 +635,19 @@ func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID return nil, err } - return account.Users[userID].Copy(), nil + user := account.Users[userID].Copy() + pat := make([]PersonalAccessToken, 0, len(user.PATs)) + for _, token := range user.PATs { + if token != nil { + pat = append(pat, *token) + } + } + user.PATsG = pat + + return user, nil } -func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { +func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) { account, err := s.getAccount(accountID) if err != nil { return nil, err @@ -931,7 +941,7 @@ func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID strin return nil } -func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { +func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) { return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not 
implemented") } @@ -950,10 +960,85 @@ func (s *FileStore) GetStoreEngine() StoreEngine { return FileStoreEngine } -func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error { +func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error { return status.Errorf(status.Internal, "SaveUsers is not implemented") } -func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { +func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { return status.Errorf(status.Internal, "SaveGroups is not implemented") } + +func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) { + return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") +} + +func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return "", "", err + } + + return account.Domain, account.DomainCategory, nil +} + +// AccountExists checks whether an account exists by the given ID. +func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) { + _, exists := s.Accounts[id] + return exists, nil +} + +func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) { + return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented") +} + +func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { + return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented") +} + +func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { + return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") +} + +func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) { + return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") +} + +func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) { + return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") + +} + +func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) { + return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented") +} + +func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) { + return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented") +} + +func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) { + return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented") +} + +func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) { + return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented") +} + +func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) { + return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented") +} + +func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) { + return nil, status.Errorf(status.Internal, 
"GetSetupKeyByID is not implemented") +} + +func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) { + return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented") +} + +func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) { + return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented") +} diff --git a/management/server/group.go b/management/server/group.go index 49720f34730..aa387c058ea 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -25,91 +25,46 @@ func (e *GroupLinkError) Error() string { return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) } -// GetGroup object of the peers -func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) +// CheckGroupPermissions validates if a user has the necessary permissions to view groups +func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { - return nil, err + return err } - user, err := account.FindUser(userID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return nil, err + return err } - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { - return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") + if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "groups are blocked for users") } - group, ok := account.Groups[groupID] - if ok { - return group, nil - } - - return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) + return nil } -// GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - user, err := account.FindUser(userID) - if err != nil { +// GetGroup returns a specific group by groupID in an account +func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { - return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") - } + return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) +} - groups := make([]*nbgroup.Group, 0, len(account.Groups)) - for _, item := range account.Groups { - groups = append(groups, item) +// GetAllGroups returns all groups in an account +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { + return nil, err } - return groups, nil + return 
am.Store.GetAccountGroups(ctx, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - matchingGroups := make([]*nbgroup.Group, 0) - for _, group := range account.Groups { - if group.Name == groupName { - matchingGroups = append(matchingGroups, group) - } - } - - if len(matchingGroups) == 0 { - return nil, status.Errorf(status.NotFound, "group with name %s not found", groupName) - } - - maxPeers := -1 - var groupWithMostPeers *nbgroup.Group - for i, group := range matchingGroups { - if len(group.Peers) > maxPeers { - maxPeers = len(group.Peers) - groupWithMostPeers = matchingGroups[i] - } - } - - return groupWithMostPeers, nil + return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) } // SaveGroup object of the peers @@ -262,6 +217,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use return nil } + allGroup, err := account.GetGroupAll() + if err != nil { + return err + } + + if allGroup.ID == groupID { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + if err = validateDeleteGroup(account, group, userId); err != nil { return err } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 5d7094b6a35..cda3bc7482b 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -262,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string } claims := s.jwtClaimsExtractor.FromToken(token) // we need to call this method because if user is new, we will automatically add it to existing or create a new account - _, _, err = s.accountManager.GetAccountFromToken(ctx, claims) + _, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims) if err != nil { return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) } diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index ffa5b9a287c..91caa15128a 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. 
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - if !(user.HasAdminPower() || user.IsServiceUser) { - util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) + settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(account) + resp := toAccountResponse(accountID, settings) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } // UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - _, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } - updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings) + updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(updatedAccount) + resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings) util.WriteJSONObject(r.Context(), w, &resp) } // DeleteAccount is a HTTP DELETE handler to delete an account func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) vars := mux.Vars(r) targetAccountID := vars["accountId"] @@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, emptyObject{}) } -func toAccountResponse(account *server.Account) *api.Account { - jwtAllowGroups := account.Settings.JWTAllowGroups +func toAccountResponse(accountID string, settings *server.Settings) *api.Account { + jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} } - settings := api.AccountSettings{ - PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()), - PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled, - GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled, - JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, - JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, + apiSettings := api.AccountSettings{ + PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()), + PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled, + GroupsPropagationEnabled: &settings.GroupsPropagationEnabled, + JwtGroupsEnabled: &settings.JWTGroupsEnabled, + JwtGroupsClaimName: &settings.JWTGroupsClaimName, JwtAllowGroups: &jwtAllowGroups, - RegularUsersViewBlocked: 
account.Settings.RegularUsersViewBlocked, + RegularUsersViewBlocked: settings.RegularUsersViewBlocked, } - if account.Settings.Extra != nil { - settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled} + if settings.Extra != nil { + apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled} } return &api.Account{ - Id: account.Id, - Settings: settings, + Id: accountID, + Settings: apiSettings, } } diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 45c7679e50f..cacb3d43010 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -23,8 +23,11 @@ import ( func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { return &AccountsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return account, admin, nil + GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return account.Id, admin.Id, nil + }, + GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + return account.Settings, nil }, UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { halfYearLimit := 180 * 24 * time.Hour diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/dns_settings_handler.go index 74b0e1a55ad..13c2101a755 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/dns_settings_handler.go @@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg // GetDNSSettings returns the DNS settings for the account func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id) + dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque // UpdateDNSSettings handles update to DNS settings of an account func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re DisabledManagementGroups: req.DisabledManagementGroups, } - err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings) + err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings) if err != nil { 
util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/dns_settings_handler_test.go index 897ae63dcef..8baea7b1538 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/dns_settings_handler_test.go @@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler { } return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") }, - GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil + GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go index 428b4c164c0..ee0c63f2822 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/events_handler.go @@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev // GetAllEvents list of the given account func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id) + accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { events[i] = toEventResponse(e) } - err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id) + err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go index 8bdd508bf77..e525cf2ee01 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/events_handler_test.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler { +func initEventsTestData(account string, events ...*activity.Event) *EventsHandler { return &EventsHandler{ accountManager: &mock_server.MockAccountManager{ GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { @@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E } return []*activity.Event{}, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Users: map[string]*server.User{ - user.Id: user, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + 
return claims.AccountId, claims.UserId, nil }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { return make([]*server.UserInfo, 0), nil @@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) { accountID := "test_account" adminUser := server.NewAdminUser("test_user") events := generateEvents(accountID, adminUser.Id) - handler := initEventsTestData(accountID, adminUser, events...) + handler := initEventsTestData(accountID, events...) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/geolocation_handler_test.go index 7f4d6dc7c72..19c916dd2e3 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/geolocation_handler_test.go @@ -11,9 +11,9 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -43,14 +43,11 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { return &GeolocationsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ - Id: claims.AccountId, - Users: map[string]*server.User{ - "test_user": user, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + return server.NewAdminUser(id), nil }, }, geolocationManager: geo, diff --git a/management/server/http/geolocations_handler.go b/management/server/http/geolocations_handler.go index af4d3116f4b..418228abfe6 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/geolocations_handler.go @@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. 
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) - _, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims) + _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + return err + } + + user, err := l.accountManager.GetUserByID(r.Context(), userID) if err != nil { return err } diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index c622d873af7..f369d1a0091 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/gorilla/mux" + nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" @@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr // GetAllGroups list for the account func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id) + groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { groupsResponse := make([]*api.Group, 0, len(groups)) for _, group := range groups { - groupsResponse = append(groupsResponse, toGroupResponse(account, group)) + groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group)) } util.WriteJSONObject(r.Context(), w, groupsResponse) @@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { // UpdateGroup handles update to a group identified by a given ID func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { return } - eg, ok := account.Groups[groupID] - if !ok { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) + existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } - allGroup, err := account.GetGroupAll() + allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID) if err != nil { util.WriteError(r.Context(), err, w) return } + if allGroup.ID == groupID { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) return @@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { ID: groupID, 
Name: req.Name, Peers: peers, - Issued: eg.Issued, - IntegrationReference: eg.IntegrationReference, + Issued: existingGroup.Issued, + IntegrationReference: existingGroup.IntegrationReference, + } + + if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil { + log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err) + util.WriteError(r.Context(), err, w) + return } - if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil { - log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } // CreateGroup handles group creation request func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { Issued: nbgroup.GroupIssuedAPI, } - err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group) + err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } // DeleteGroup handles group deletion request func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - aID := account.Id groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { @@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - allGroup, err := account.GetGroupAll() - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - if allGroup.ID == groupID { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w) - return - } - - err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID) + err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID) if err != nil { _, ok := err.(*server.GroupLinkError) if ok { @@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { // GetGroup returns a group func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), 
err, w) return } + groupID := mux.Vars(r)["groupId"] + if len(groupID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) + return + } - switch r.Method { - case http.MethodGet: - groupID := mux.Vars(r)["groupId"] - if len(groupID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) - return - } - - group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } + group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group)) - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w) + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } + + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group)) + } -func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { +func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { + peersMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + cache := make(map[string]api.PeerMinimum) gr := api.Group{ Id: group.ID, @@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { for _, pid := range group.Peers { _, ok := cache[pid] if !ok { - peer, ok := account.Peers[pid] + peer, ok := peersMap[pid] if !ok { continue } diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index d5ed07c9ef3..7f3c81f1872 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/gorilla/mux" "github.com/magiconair/properties/assert" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" @@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { +func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { return &GroupsHandler{ accountManager: &mock_server.MockAccountManager{ SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { @@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { return nil }, GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { - if groupID != "idofthegroup" { + groups := map[string]*nbgroup.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, + } + + for _, group := range initGroups { + groups[group.ID] = group + } + + group, ok := groups[groupID] + if !ok { return nil, status.Errorf(status.NotFound, "not found") } - if groupID == "id-jwt-group" { - return &nbgroup.Group{ - ID: "id-jwt-group", - Name: "Default Group", - Issued: nbgroup.GroupIssuedJWT, - }, nil + + return group, nil + }, + GetAccountIDFromTokenFunc: func(_ 
context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) { + if groupName == "All" { + return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil } - return &nbgroup.Group{ - ID: "idofthegroup", - Name: "Group", - Issued: nbgroup.GroupIssuedAPI, - }, nil + + return nil, fmt.Errorf("unknown group name") }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Peers: TestPeers, - Users: map[string]*server.User{ - user.Id: user, - }, - Groups: map[string]*nbgroup.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, - }, - }, user, nil + GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + return maps.Values(TestPeers), nil }, DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { if groupID == "linked-grp" { @@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) { Name: "Group", } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser, group) + p := initGroupTestData(group) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) { }, } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser) + p := initGroupTestData() for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) { }, } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser) + p := initGroupTestData() for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index c6e00bb2d7f..e7a2bc2ae8a 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg // GetAllNameservers returns the list of nameserver groups for the account func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id) + nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re // CreateNameserverGroup handles nameserver group creation request func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + 
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt return } - nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) + nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled) if err != nil { util.WriteError(r.Context(), err, w) return @@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt // UpdateNameserverGroup handles update to a nameserver group identified by a given ID func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt SearchDomainsEnabled: req.SearchDomainsEnabled, } - err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup) + err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup) if err != nil { util.WriteError(r.Context(), err, w) return @@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt // DeleteNameserverGroup handles nameserver group deletion request func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id) + err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt // GetNameserverGroup handles a nameserver group Get request identified by ID func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R return } - nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID) + nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git 
a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 28b080571e1..98c2e402d84 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -18,7 +18,6 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -29,14 +28,6 @@ const ( testNSGroupAccountID = "test_id" ) -var testingNSAccount = &server.Account{ - Id: testNSGroupAccountID, - Domain: "hotmail.com", - Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), - }, -} - var baseExistingNSGroup = &nbdns.NameServerGroup{ ID: existingNSGroupID, Name: "super", @@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler { } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingNSAccount, testingAccount.Users["test_user"], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 9d8448d3dea..dfa9563e3b9 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH // GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) - userID := vars["userId"] - if len(userID) == 0 { + targetUserID := vars["userId"] + if len(targetUserID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID) + pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { // GetToken is HTTP GET handler that returns a personal access token for the given user func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) + pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w
http.ResponseWriter, r *http.Request) { // CreateToken is HTTP POST handler that creates a personal access token for the given user func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) + pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn) if err != nil { util.WriteError(r.Context(), err, w) return @@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { // DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) + err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index b72f71468df..c28228a506e 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -77,8 +77,8 @@ func initPATTestData() *PATHandler { }, nil }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testAccount, testAccount.Users[existingUserID], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if accountID != existingAccountID { @@ -119,7 +119,7 @@ func initPATTestData() *PATHandler { return jwtclaims.AuthorizationClaims{ UserId: existingUserID, Domain: testDomain, - AccountId: testNSGroupAccountID, + AccountId: existingAccountID, } }), ), diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 5a2190d83fa..4fbbc3106d3 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -74,7 +74,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := 
json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -96,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, } } - peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update) + peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) if err != nil { util.WriteError(ctx, err, w) return @@ -130,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, // HandlePeer handles all peer requests for GET, PUT and DELETE operations func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -144,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodDelete: - h.deletePeer(r.Context(), account.Id, user.Id, peerID, w) + h.deletePeer(r.Context(), accountID, userID, peerID, w) return - case http.MethodPut: - h.updatePeer(r.Context(), account, user, peerID, w, r) - return - case http.MethodGet: - h.getPeer(r.Context(), account, peerID, user.Id, w) + case http.MethodGet, http.MethodPut: + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + if r.Method == http.MethodGet { + h.getPeer(r.Context(), account, peerID, userID, w) + } else { + h.updatePeer(r.Context(), account, userID, peerID, w, r) + } return default: util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) @@ -159,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { // GetAllPeers returns a list of all peers associated with a provided account func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id) + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -179,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - respBody := make([]*api.PeerBatch, 0, len(peers)) - for _, peer := range peers { + respBody := make([]*api.PeerBatch, 0, len(account.Peers)) + for _, peer := range account.Peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(r.Context(), err, w) @@ -214,7 +216,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. 
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -227,6 +229,18 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request return } + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + user, err := account.FindUser(userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + // If the user is regular user and does not own the peer // with the given peerID return an empty list if !user.HasAdminPower() && !user.IsServiceUser { diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index dae264fff11..f933eee1497 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -13,16 +13,15 @@ import ( "time" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -70,7 +69,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { peersMap := make(map[string]*nbpeer.Peer) for _, peer := range peers { peersMap[peer.ID] = peer.Copy() @@ -78,7 +80,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { policy := &server.Policy{ ID: "policy", - AccountID: claims.AccountId, + AccountID: accountID, Name: "policy", Enabled: true, Rules: []*server.PolicyRule{ @@ -100,7 +102,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { srvUser.IsServiceUser = true account := &server.Account{ - Id: claims.AccountId, + Id: accountID, Domain: "hotmail.com", Peers: peersMap, Users: map[string]*server.User{ @@ -111,7 +113,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { Groups: map[string]*nbgroup.Group{ "group1": { ID: "group1", - AccountID: claims.AccountId, + AccountID: accountID, Name: "group1", Issued: "api", Peers: maps.Keys(peersMap), @@ -132,7 +134,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { }, } - return account, account.Users[claims.UserId], nil + return account, nil }, HasConnectedChannelFunc: func(peerID string) bool { statuses := make(map[string]struct{}) @@ -279,9 +281,15 @@ func TestGetPeers(t *testing.T) { // hardcode this check for now as we only have two peers in this suite assert.Equal(t, len(respBody), 2) - 
assert.Equal(t, respBody[1].Connected, false) - got = respBody[0] + for _, peer := range respBody { + if peer.Id == testPeerID { + got = peer + } else { + assert.Equal(t, peer.Connected, false) + } + } + } else { got = &api.Peer{} err = json.Unmarshal(content, got) diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 9622668f49f..225d7e1f30c 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -6,6 +6,7 @@ import ( "strconv" "github.com/gorilla/mux" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" "github.com/netbirdio/netbird/management/server" @@ -35,21 +36,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllPolicies list for the account func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id) + listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - policies := []*api.Policy{} - for _, policy := range accountPolicies { - resp := toPolicyResponse(account, policy) + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policies := make([]*api.Policy, 0, len(listPolicies)) + for _, policy := range listPolicies { + resp := toPolicyResponse(allGroups, policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return @@ -63,7 +70,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { // UpdatePolicy handles update to a policy identified by a given ID func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,41 +83,29 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { return } - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break - } - } - if policyIdx < 0 { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) + _, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } - h.savePolicy(w, r, account, user, policyID) + h.savePolicy(w, r, accountID, userID, policyID) } // CreatePolicy handles policy creation request func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - h.savePolicy(w, r, account, user, "") + h.savePolicy(w, r, 
accountID, userID, "") } // savePolicy handles policy creation and update -func (h *Policies) savePolicy( - w http.ResponseWriter, - r *http.Request, - account *server.Account, - user *server.User, - policyID string, -) { +func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { var req api.PutApiPoliciesPolicyIdJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -127,6 +122,8 @@ func (h *Policies) savePolicy( return } + isUpdate := policyID != "" + if policyID == "" { policyID = xid.New().String() } @@ -141,8 +138,8 @@ func (h *Policies) savePolicy( pr := server.PolicyRule{ ID: policyID, // TODO: when policy can contain multiple rules, need refactor Name: rule.Name, - Destinations: groupMinimumsToStrings(account, rule.Destinations), - Sources: groupMinimumsToStrings(account, rule.Sources), + Destinations: rule.Destinations, + Sources: rule.Sources, Bidirectional: rule.Bidirectional, } @@ -207,15 +204,21 @@ func (h *Policies) savePolicy( } if req.SourcePostureChecks != nil { - policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks) + policy.SourcePostureChecks = *req.SourcePostureChecks } - if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil { + if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { util.WriteError(r.Context(), err, w) return } - resp := toPolicyResponse(account, &policy) + resp := toPolicyResponse(allGroups, &policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return @@ -227,12 +230,11 @@ func (h *Policies) savePolicy( // DeletePolicy handles policy deletion request func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - aID := account.Id vars := mux.Vars(r) policyID := vars["policyId"] @@ -241,7 +243,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil { + if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil { util.WriteError(r.Context(), err, w) return } @@ -252,40 +254,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { // GetPolicy handles a group Get request identified by ID func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - switch r.Method { - case http.MethodGet: - vars := mux.Vars(r) - policyID := vars["policyId"] - if len(policyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) - return - } + vars := 
mux.Vars(r) + policyID := vars["policyId"] + if len(policyID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + return + } - policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } + policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } - resp := toPolicyResponse(account, policy) - if len(resp.Rules) == 0 { - util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) - return - } + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } - util.WriteJSONObject(r.Context(), w, resp) - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w) + resp := toPolicyResponse(allGroups, policy) + if len(resp.Rules) == 0 { + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) + return } + + util.WriteJSONObject(r.Context(), w, resp) } -func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy { +func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy { + groupsMap := make(map[string]*nbgroup.Group) + for _, group := range groups { + groupsMap[group.ID] = group + } + cache := make(map[string]api.GroupMinimum) ap := &api.Policy{ Id: &policy.ID, @@ -306,16 +314,18 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic Protocol: api.PolicyRuleProtocol(r.Protocol), Action: api.PolicyRuleAction(r.Action), } + if len(r.Ports) != 0 { portsCopy := r.Ports rule.Ports = &portsCopy } + for _, gid := range r.Sources { _, ok := cache[gid] if ok { continue } - if group, ok := account.Groups[gid]; ok { + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, @@ -325,13 +335,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic cache[gid] = minimum } } + for _, gid := range r.Destinations { cachedMinimum, ok := cache[gid] if ok { rule.Destinations = append(rule.Destinations, cachedMinimum) continue } - if group, ok := account.Groups[gid]; ok { + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, @@ -345,28 +356,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic } return ap } - -func groupMinimumsToStrings(account *server.Account, gm []string) []string { - result := make([]string, 0, len(gm)) - for _, g := range gm { - if _, ok := account.Groups[g]; !ok { - continue - } - result = append(result, g) - } - return result -} - -func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) - for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } - } - - } - return result -} diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 06274fb072d..228ebcbceef 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return policy, nil }, - 
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } return nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { + return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil + }, + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { + user := server.NewAdminUser(userID) return &server.Account{ - Id: claims.AccountId, + Id: accountID, Domain: "hotmail.com", Policies: []*server.Policy{ {ID: "id-existed"}, @@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { Users: map[string]*server.User{ "test_user": user, }, - }, user, nil + }, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 059cb3b8055..1d020e9bcb7 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa // GetAllPostureChecks list for the account func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id) + listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - postureChecks := []*api.PostureCheck{} - for _, postureCheck := range accountPostureChecks { + postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks)) + for _, postureCheck := range listPostureChecks { postureChecks = append(postureChecks, postureCheck.ToAPIResponse()) } @@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt // UpdatePostureCheck handles update to a posture check identified by a given ID func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http return } - postureChecksIdx := -1 - for i, postureCheck := range account.PostureChecks { - if postureCheck.ID == postureChecksID { - postureChecksIdx = i - break - } - } - if postureChecksIdx < 0 { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture 
checks id %s", postureChecksID), w) + _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } - p.savePostureChecks(w, r, account, user, postureChecksID) + p.savePostureChecks(w, r, accountID, userID, postureChecksID) } // CreatePostureCheck handles posture check creation request func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - p.savePostureChecks(w, r, account, user, "") + p.savePostureChecks(w, r, accountID, userID, "") } // GetPostureCheck handles a posture check Get request identified by ID func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re return } - postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id) + postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re // DeletePostureCheck handles posture check deletion request func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http return } - if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil { + if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil { util.WriteError(r.Context(), err, w) return } @@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http } // savePostureChecks handles posture checks create and update -func (p *PostureChecksHandler) savePostureChecks( - w http.ResponseWriter, - r *http.Request, - account *server.Account, - user *server.User, - postureChecksID string, -) { +func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { var ( err error req api.PostureCheckUpdate @@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks( return } - if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil { + if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/posture_checks_handler_test.go 
b/management/server/http/posture_checks_handler_test.go index 974edafde2f..02f0f0d8308 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -14,7 +14,6 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH } return accountPostureChecks, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ - Id: claims.AccountId, - Users: map[string]*server.User{ - "test_user": user, - }, - PostureChecks: postureChecks, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, }, geolocationManager: &geolocation.Geolocation{}, diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index 18c347334ed..0932e64455e 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro // GetAllRoutes returns the list of routes for the account func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id) + routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { // CreateRoute handles route creation request func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -117,15 +117,9 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { peerGroupIds = *req.PeerGroups } - // Do not allow non-Linux peers - if peer := account.GetPeer(peerId); peer != nil { - if peer.Meta.GoOS != "linux" { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w) - return - } - } - - newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute) + newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, 
req.Enabled, userID, req.KeepRoute, + ) if err != nil { util.WriteError(r.Context(), err, w) return @@ -168,7 +162,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro // UpdateRoute handles update to a route identified by a given ID func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -181,7 +175,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + _, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -204,14 +198,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { peerID = *req.Peer } - // do not allow non Linux peers - if peer := account.GetPeer(peerID); peer != nil { - if peer.Meta.GoOS != "linux" { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) - return - } - } - newRoute := &route.Route{ ID: route.ID(routeID), NetID: route.NetID(req.NetworkId), @@ -247,7 +233,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { newRoute.PeerGroups = *req.PeerGroups } - err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute) + err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute) if err != nil { util.WriteError(r.Context(), err, w) return @@ -265,7 +251,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { // DeleteRoute handles route deletion request func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -277,7 +263,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -289,7 +275,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { // GetRoute handles a route Get request identified by ID func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -301,7 +287,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { return } - foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w) 
return diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index 40075eb9d8b..2c367cac399 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -112,6 +112,12 @@ func initRoutesTestData() *RoutesHandler { if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID { return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0]) } + if peerID != "" { + if peerID == nonLinuxExistingPeerID { + return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + return &route.Route{ ID: existingRouteID, NetID: netID, @@ -131,6 +137,11 @@ func initRoutesTestData() *RoutesHandler { if r.Peer == notFoundPeerID { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer) } + + if r.Peer == nonLinuxExistingPeerID { + return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + return nil }, DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error { @@ -139,8 +150,9 @@ func initRoutesTestData() *RoutesHandler { } return nil }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingAccount, testingAccount.Users["test_user"], nil + GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + //return testingAccount, testingAccount.Users["test_user"], nil + return testingAccount.Id, testingAccount.Users["test_user"].Id, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 8ee7dfabaa7..8514f0b556b 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) // CreateSetupKey is a POST requests that creates a new SetupKey func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request if req.Ephemeral != nil { ephemeral = *req.Ephemeral } - setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, - req.AutoGroups, req.UsageLimit, user.Id, ephemeral) + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, + req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { util.WriteError(r.Context(), err, w) return @@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request // GetSetupKey is a GET request to get a SetupKey by ID func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { 
util.WriteError(r.Context(), err, w) return @@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { return } - key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID) + key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { // UpdateSetupKey is a PUT request to update server.SetupKey func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request newKey.Name = req.Name newKey.Id = keyID - newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id) + newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request // GetAllSetupKeys is a GET request that returns a list of SetupKey func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id) + setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index bfa0ec008d9..2d15287af25 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" @@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup ) *SetupKeysHandler { return &SetupKeysHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: testAccountID, - Domain: "hotmail.com", - Users: map[string]*server.User{ - user.Id: user, - }, - SetupKeys: map[string]*server.SetupKey{ - defaultKey.Key: defaultKey, - }, - Groups: map[string]*nbgroup.Group{ - "group-1": {ID: "group-1", Peers: []string{"A", "B"}}, - "id-all": {ID: "id-all", Name: "All"}, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, CreateSetupKeyFunc: func(_ 
context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 2c2aed84284..6e151a0da3a 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) - userID := vars["userId"] - if len(userID) == 0 { + targetUserID := vars["userId"] + if len(targetUserID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - existingUser, ok := account.Users[userID] - if !ok { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) + existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID) + if err != nil { + util.WriteError(r.Context(), err, w) return } @@ -78,8 +78,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{ - Id: userID, + newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{ + Id: targetUserID, Role: userRole, AutoGroups: req.AutoGroups, Blocked: req.IsBlocked, @@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID) + err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{ Email: email, Name: name, Role: req.Role, @@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - data, err := h.accountManager.GetUsersFromAccount(r.Context(), 
account.Id, user.Id) + data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID) + err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index a78ac3a4e04..f3d989da19f 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{ func initUsersTestData() *UsersHandler { return &UsersHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return usersTestAccount, usersTestAccount.Users[claims.UserId], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return usersTestAccount.Id, claims.UserId, nil + }, + GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + return usersTestAccount.Users[id], nil }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { users := make([]*server.UserInfo, 0) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 49532525279..df12ec1c437 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -23,10 +23,11 @@ import ( type MockAccountManager struct { GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) + GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) + GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -48,7 +49,7 @@ type MockAccountManager struct { GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) 
(*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) @@ -79,7 +80,7 @@ type MockAccountManager struct { DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) - GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) + GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string @@ -105,6 +106,9 @@ type MockAccountManager struct { SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) + GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) + GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) + GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) } func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { @@ -190,16 +194,14 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface -func (am *MockAccountManager) GetAccountByUserOrAccountID( - ctx context.Context, userId, accountId, domain string, -) (*server.Account, error) { - if am.GetAccountByUserOrAccountIdFunc != nil { - return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain) +// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface +func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { + if am.GetAccountIDByUserOrAccountIdFunc != nil { + return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) } - return nil, status.Errorf( + return "", status.Errorf( codes.Unimplemented, - "method GetAccountByUserOrAccountID is not implemented", + "method GetAccountIDByUserOrAccountID is not implemented", ) } @@ -377,9 +379,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am 
*MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(ctx, accountID, userID, policy) + return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) } return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } @@ -601,14 +603,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } -// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface -func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, - error, -) { - if am.GetAccountFromTokenFunc != nil { - return am.GetAccountFromTokenFunc(ctx, claims) +// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface +func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + if am.GetAccountIDFromTokenFunc != nil { + return am.GetAccountIDFromTokenFunc(ctx, claims) } - return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") + return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented") } func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { @@ -802,3 +802,33 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe } return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented") } + +// GetAccountByID mocks GetAccountByID of the AccountManager interface +func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) { + if am.GetAccountByIDFunc != nil { + return am.GetAccountByIDFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented") +} + +// GetUserByID mocks GetUserByID of the AccountManager interface +func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) { + if am.GetUserByIDFunc != nil { + return am.GetUserByIDFunc(ctx, id) + } + return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented") +} + +func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + if am.GetAccountSettingsFunc != nil { + return am.GetAccountSettingsFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented") +} + +func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) { + if am.GetAccountFunc != nil { + return am.GetAccountFunc(ctx, accountID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 636f7cfee43..0eb5d9ae4a4 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -19,30 +19,16 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from 
account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - user, err := account.FindUser(userID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !(user.HasAdminPower() || user.IsServiceUser) { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups") - } - - nsGroup, found := account.NameServerGroups[nsGroupID] - if found { - return nsGroup.Copy(), nil + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") } - return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) + return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) } // CreateNameServerGroup creates and saves a new nameserver group @@ -159,30 +145,16 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") } - nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups)) - for _, item := range account.NameServerGroups { - nsGroups = append(nsGroups, item.Copy()) - } - - return nsGroups, nil + return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) } func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 4b2ec66c68d..d329e04bc46 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -251,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { Action: PolicyTrafficActionAccept, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -299,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return diff --git a/management/server/policy.go b/management/server/policy.go index aaf9b6e72d0..5d07ba8f8a0 100644 --- a/management/server/policy.go +++ 
b/management/server/policy.go @@ -3,6 +3,7 @@ package server import ( "context" _ "embed" + "slices" "strconv" "strings" @@ -314,34 +315,20 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, // GetPolicy from the store func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - user, err := account.FindUser(userID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } - for _, policy := range account.Policies { - if policy.ID == policyID { - return policy, nil - } - } - - return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID) + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -350,7 +337,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - exists := am.savePolicy(account, policy) + if err = am.savePolicy(account, policy, isUpdate); err != nil { + return err + } account.Network.IncSerial() if err = am.Store.SaveAccount(ctx, account); err != nil { @@ -358,7 +347,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user } action := activity.PolicyAdded - if exists { + if isUpdate { action = activity.PolicyUpdated } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) @@ -397,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po // ListPolicies from the store func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies") + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } - return account.Policies, nil + return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { @@ -434,18 +415,34 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) return policy, nil } -func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists bool) { - for i, p := range account.Policies { - if p.ID == policy.ID { - 
account.Policies[i] = policy - exists = true - break - } +// savePolicy saves or updates a policy in the given account. +// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. +func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error { + for index, rule := range policyToSave.Rules { + rule.Sources = filterValidGroupIDs(account, rule.Sources) + rule.Destinations = filterValidGroupIDs(account, rule.Destinations) + policyToSave.Rules[index] = rule } - if !exists { - account.Policies = append(account.Policies, policy) + + if policyToSave.SourcePostureChecks != nil { + policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) } - return + + if isUpdate { + policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) + if policyIdx < 0 { + return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) + } + + // Update the existing policy + account.Policies[policyIdx] = policyToSave + return nil + } + + // Add the new policy to the account + account.Policies = append(account.Policies, policyToSave) + + return nil } func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { @@ -560,3 +557,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { } return nil } + +// filterValidPostureChecks filters and returns the posture check IDs from the given list +// that are valid within the provided account. +func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { + result := make([]string, 0, len(postureChecksIds)) + for _, id := range postureChecksIds { + for _, postureCheck := range account.PostureChecks { + if id == postureCheck.ID { + result = append(result, id) + continue + } + } + } + return result +} + +// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. 
+func filterValidGroupIDs(account *Account, groupIDs []string) []string { + result := make([]string, 0, len(groupIDs)) + for _, groupID := range groupIDs { + if _, exists := account.Groups[groupID]; exists { + result = append(result, groupID) + } + } + return result +} diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 4180550e6a9..9a4b679cef5 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -15,30 +15,16 @@ const ( ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - user, err := account.FindUser(userID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.HasAdminPower() { + if !user.HasAdminPower() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } - for _, postureChecks := range account.PostureChecks { - if postureChecks.ID == postureChecksID { - return postureChecks, nil - } - } - - return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID) + return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) } func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { @@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun } func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() { + if !user.HasAdminPower() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } - return account.PostureChecks, nil + return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) } func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { diff --git a/management/server/route.go b/management/server/route.go index 064f3c10596..6c1c8b1b3c0 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -17,29 +17,16 @@ import ( // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - wantedRoute, found := 
account.Routes[routeID] - if found { - return wantedRoute, nil - } - - return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) + return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. @@ -134,6 +121,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } + // Do not allow non-Linux peers + if peer := account.GetPeer(peerID); peer != nil { + if peer.Meta.GoOS != "linux" { + return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + if len(domains) > 0 && prefix.IsValid() { return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } @@ -234,6 +228,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } + // Do not allow non-Linux peers + if peer := account.GetPeer(routeToSave.Peer); peer != nil { + if peer.Meta.GoOS != "linux" { + return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } @@ -311,29 +312,16 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - routes := make([]*route.Route, 0, len(account.Routes)) - for _, item := range account.Routes { - routes = append(routes, item) - } - - return routes, nil + return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) } func toProtocolRoute(route *route.Route) *proto.Route { diff --git a/management/server/route_test.go b/management/server/route_test.go index 506bfb0a830..4533c6b7e5c 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1205,7 +1205,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) + err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) require.NoError(t, err) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 859f1b0b918..9521e22d339 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -330,26 +330,24 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str // ListSetupKeys returns a list of all setup keys of the account func (am *DefaultAccountManager) 
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") } - if !user.HasAdminPower() && !user.IsServiceUser { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") + setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err } - keys := make([]*SetupKey, 0, len(account.SetupKeys)) - for _, key := range account.SetupKeys { + keys := make([]*SetupKey, 0, len(setupKeys)) + for _, key := range setupKeys { var k *SetupKey - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() { k = key.HiddenCopy(999) } else { k = key.Copy() @@ -362,44 +360,30 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() && !user.IsServiceUser { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") } - var foundKey *SetupKey - for _, key := range account.SetupKeys { - if key.Id == keyID { - foundKey = key.Copy() - break - } - } - if foundKey == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") + setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + if err != nil { + return nil, err } // the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file) - if foundKey.UpdatedAt.IsZero() { - foundKey.UpdatedAt = foundKey.CreatedAt + if setupKey.UpdatedAt.IsZero() { + setupKey.UpdatedAt = setupKey.CreatedAt } - if !(user.HasAdminPower() || user.IsServiceUser) { - foundKey = foundKey.HiddenCopy(999) + if !user.IsAdminOrServiceUser() { + setupKey = setupKey.HiddenCopy(999) } - return foundKey, nil + return setupKey, nil } func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 8fa5f9d0588..85c68ef4488 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -36,6 +36,7 @@ const ( idQueryCondition = "id = ?" keyQueryCondition = "key = ?" accountAndIDQueryCondition = "account_id = ? and id = ?" + accountIDCondition = "account_id = ?" 
peerNotFoundFMT = "peer %s not found" ) @@ -399,20 +400,30 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { } func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { - var account Account + accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if err != nil { + return nil, err + } + + // TODO: rework to not call GetAccount + return s.GetAccount(ctx, accountID) +} - result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", - strings.ToLower(domain), true, PrivateCategory) +func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { + var accountID string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). + Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", + strings.ToLower(domain), true, PrivateCategory, + ).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") + return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting account from store") + return "", status.Errorf(status.Internal, "issue getting account from store") } - // TODO: rework to not call GetAccount - return s.GetAccount(ctx, account.Id) + return accountID, nil } func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { @@ -478,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { var user User result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). 
- First(&user, idQueryCondition, userID) + Preload(clause.Associations).First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) @@ -491,7 +502,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.Find(&groups, idQueryCondition, accountID) + result := s.db.Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -661,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) } func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { - var user User var accountID string - result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -1028,3 +1038,152 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store { func (s *SqlStore) GetDB() *gorm.DB { return s.db } + +func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) { + var accountDNSSettings AccountDNSSettings + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + First(&accountDNSSettings, idQueryCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "dns settings not found") + } + return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) + } + return &accountDNSSettings.DNSSettings, nil +} + +// AccountExists checks whether an account exists by the given ID. +func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { + var accountID string + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Select("id").First(&accountID, idQueryCondition, id) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return false, nil + } + return false, result.Error + } + + return accountID != "", nil +} + +// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. +func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { + var account Account + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). + Where(idQueryCondition, accountID).First(&account) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", "", status.Errorf(status.NotFound, "account not found") + } + return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error) + } + + return account.Domain, account.DomainCategory, nil +} + +// GetGroupByID retrieves a group by ID and account ID. 
+func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { + return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID) +} + +// GetGroupByName retrieves a group by name and account ID. +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { + var group nbgroup.Group + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). + Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID) + if err := result.Error; err != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "group not found") + } + return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) + } + return &group, nil +} + +// GetAccountPolicies retrieves policies for an account. +func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { + return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) +} + +// GetPolicyByID retrieves a policy by its ID and account ID. +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { + return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID) +} + +// GetAccountPostureChecks retrieves posture checks for an account. +func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { + return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetPostureChecksByID retrieves posture checks by their ID and account ID. +func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { + return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID) +} + +// GetAccountRoutes retrieves network routes for an account. +func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { + return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetRouteByID retrieves a route by its ID and account ID. +func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { + return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID) +} + +// GetAccountSetupKeys retrieves setup keys for an account. +func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { + return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetSetupKeyByID retrieves a setup key by its ID and account ID. +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) { + return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID) +} + +// GetAccountNameServerGroups retrieves name server groups for an account. 
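For illustration only (this snippet is not part of the patch): the one-line getters above all funnel into the small generic helpers getRecords/getRecordByID defined a few hunks below. A minimal standalone sketch of that pattern, assuming a simplified Route type, an in-memory SQLite database, and none of the locking clauses or status-coded errors used by the real SqlStore:

    package main

    import (
        "fmt"

        "gorm.io/driver/sqlite"
        "gorm.io/gorm"
    )

    // Route is a stand-in for the real route.Route model.
    type Route struct {
        ID        string `gorm:"primaryKey"`
        AccountID string
        Network   string
    }

    // getRecords loads all records of type T that belong to an account.
    func getRecords[T any](db *gorm.DB, accountID string) ([]T, error) {
        var recs []T
        if err := db.Find(&recs, "account_id = ?", accountID).Error; err != nil {
            return nil, fmt.Errorf("get records: %w", err)
        }
        return recs, nil
    }

    // getRecordByID loads a single record of type T by account ID and record ID.
    func getRecordByID[T any](db *gorm.DB, recordID, accountID string) (*T, error) {
        var rec T
        if err := db.First(&rec, "account_id = ? and id = ?", accountID, recordID).Error; err != nil {
            return nil, fmt.Errorf("get record: %w", err)
        }
        return &rec, nil
    }

    func main() {
        db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
        if err != nil {
            panic(err)
        }
        _ = db.AutoMigrate(&Route{})
        _ = db.Create(&Route{ID: "r1", AccountID: "acc1", Network: "10.0.0.0/24"}).Error

        routes, _ := getRecords[Route](db, "acc1")          // all routes of the account
        route, _ := getRecordByID[Route](db, "r1", "acc1")  // a single route
        fmt.Println(len(routes), route.Network)
    }

With helpers like these, the per-entity store methods stay trivial: GetRouteByID reduces to a single getRecordByID call, and supporting a new entity type needs no new query code.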
+func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { + return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetNameServerGroupByID retrieves a name server group by its ID and account ID. +func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { + return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) +} + +// getRecords retrieves records from the database based on the account ID. +func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { + var record []T + + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID) + if err := result.Error; err != nil { + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] + + return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) + } + + return record, nil +} + +// getRecordByID retrieves a record by its ID and account ID from the database. +func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) { + var record T + + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&record, accountAndIDQueryCondition, accountID, recordID) + if err := result.Error; err != nil { + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] + + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "%s not found", recordType) + } + return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err) + } + return &record, nil +} diff --git a/management/server/store.go b/management/server/store.go index 84b3b140c6c..f34a73c2d41 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/netbirdio/netbird/dns" log "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -39,53 +40,81 @@ const ( type Store interface { GetAllAccounts(ctx context.Context) []*Account GetAccount(ctx context.Context, accountID string) (*Account, error) - DeleteAccount(ctx context.Context, account *Account) error + AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) + GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) - GetAccountIDByUserID(peerKey string) (string, error) + GetAccountIDByUserID(userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) - GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) + GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) + 
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) + GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) + SaveAccount(ctx context.Context, account *Account) error + DeleteAccount(ctx context.Context, account *Account) error + GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) - GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) - SaveAccount(ctx context.Context, account *Account) error SaveUsers(accountID string, users map[string]*User) error - SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error + GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error + + GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) + SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + + GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + + GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) + GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) + + GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error + AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error + AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error + GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error + SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error + SavePeerLocation(accountID string, peer *nbpeer.Peer) error + + GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) + IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error + GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) + + GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) + GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) + + GetAccountNameServerGroups(ctx context.Context, 
lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) + GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) + + GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) + IncrementNetworkSerial(ctx context.Context, accountId string) error + GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) + GetInstallationID() string SaveInstallationID(ctx context.Context, ID string) error + // AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock AcquireWriteLockByUID(ctx context.Context, uniqueID string) func() // AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock AcquireReadLockByUID(ctx context.Context, uniqueID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock(ctx context.Context) func() - SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error - SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error - SavePeerLocation(accountID string, peer *nbpeer.Peer) error - SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error + // Close should close the store persisting all unsaved data. Close(ctx context.Context) error // GetStoreEngine should return StoreEngine of the current store implementation. // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine - GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) - GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) - GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) - GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error - AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error - GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) - AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error - AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error - IncrementNetworkSerial(ctx context.Context, accountId string) error - GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error } diff --git a/management/server/user.go b/management/server/user.go index 9e60bb94ba4..6d01561c6cc 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -94,6 +94,11 @@ func (u *User) HasAdminPower() bool { return u.Role == UserRoleAdmin || u.Role == UserRoleOwner } +// IsAdminOrServiceUser checks if the user has admin power or is a service user. +func (u *User) IsAdminOrServiceUser() bool { + return u.HasAdminPower() || u.IsServiceUser +} + // ToUserInfo converts a User object to a UserInfo object. 
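For illustration only (this snippet is not part of the patch): IsAdminOrServiceUser is the pivot of the refactor above. Instead of loading the whole account under a write lock, each read path now resolves the calling user, checks role and account membership, and then reads the requested objects straight from the store. A minimal sketch of that guarded-read shape, assuming simplified in-memory types rather than the real Store and status packages:

    package main

    import (
        "errors"
        "fmt"
    )

    type User struct {
        ID            string
        AccountID     string
        Role          string
        IsServiceUser bool
    }

    func (u *User) HasAdminPower() bool { return u.Role == "admin" || u.Role == "owner" }

    func (u *User) IsAdminOrServiceUser() bool { return u.HasAdminPower() || u.IsServiceUser }

    // listRoutes shows the guarded-read shape used throughout the patch:
    // resolve the user, verify permissions, then read from the store.
    func listRoutes(users map[string]*User, routesByAccount map[string][]string, accountID, userID string) ([]string, error) {
        user, ok := users[userID]
        if !ok {
            return nil, errors.New("user not found")
        }
        // only admins and service users may read, and only within their own account
        if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
            return nil, errors.New("permission denied")
        }
        return routesByAccount[accountID], nil
    }

    func main() {
        users := map[string]*User{"u1": {ID: "u1", AccountID: "acc1", Role: "admin"}}
        routes := map[string][]string{"acc1": {"10.0.0.0/24"}}
        got, err := listRoutes(users, routes, "acc1", "u1")
        fmt.Println(got, err)
    }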
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { autoGroups := u.AutoGroups @@ -357,39 +362,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return newUser.ToUserInfo(idpUser, account.Settings) } +func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) { + return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id) +} + // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { - account, _, err := am.GetAccountFromToken(ctx, claims) + accountID, userID, err := am.GetAccountIDFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - account, err = am.Store.GetAccount(ctx, account.Id) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return nil, fmt.Errorf("failed to get an account from store %v", err) - } - - user, ok := account.Users[claims.UserId] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + return nil, err } - // this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC + // this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. newLogin := user.LastDashboardLoginChanged(claims.LastLogin) - err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin) + err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin) if err != nil { log.WithContext(ctx).Errorf("failed saving user last login: %v", err) } if newLogin { meta := map[string]any{"timestamp": claims.LastLogin} - am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) + am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta) } return user, nil @@ -642,63 +643,48 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string // GetPAT returns a specific PAT from a user func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) if err != nil { - return nil, status.Errorf(status.NotFound, "account not found: %s", err) - } - - targetUser, ok := account.Users[targetUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + return nil, err } - executingUser, ok := account.Users[initiatorUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + if err != nil { + return nil, err } - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { - return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser") + if (initiatorUserID 
!= targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } - pat := targetUser.PATs[tokenID] - if pat == nil { - return nil, status.Errorf(status.NotFound, "PAT not found") + for _, pat := range targetUser.PATsG { + if pat.ID == tokenID { + return pat.Copy(), nil + } } - return pat, nil + return nil, status.Errorf(status.NotFound, "PAT not found") } // GetAllPATs returns all PATs for a user func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) if err != nil { - return nil, status.Errorf(status.NotFound, "account not found: %s", err) - } - - targetUser, ok := account.Users[targetUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + return nil, err } - executingUser, ok := account.Users[initiatorUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + if err != nil { + return nil, err } - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { + if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } - var pats []*PersonalAccessToken - for _, pat := range targetUser.PATs { - pats = append(pats, pat) + pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG)) + for _, pat := range targetUser.PATsG { + pats = append(pats, pat.Copy()) } return pats, nil diff --git a/management/server/user_test.go b/management/server/user_test.go index 2720602765a..e394ef840db 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -199,7 +199,8 @@ func TestUser_GetPAT(t *testing.T) { defer store.Close(context.Background()) account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ - Id: mockUserID, + Id: mockUserID, + AccountID: mockAccountID, PATs: map[string]*PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, @@ -231,7 +232,8 @@ func TestUser_GetAllPATs(t *testing.T) { defer store.Close(context.Background()) account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ - Id: mockUserID, + Id: mockUserID, + AccountID: mockAccountID, PATs: map[string]*PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, @@ -796,7 +798,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") + assert.NoError(t, err) + + acc, err := am.Store.GetAccount(context.Background(), accID) assert.NoError(t, err) for _, id := range tc.expectedDeleted { From 58ff7ab797fcde081b3a0a802487f13a4ab4945a Mon Sep 17 00:00:00 2001 From: adasauce <60991921+adasauce@users.noreply.github.com> Date: Fri, 27 Sep 2024 16:21:34 -0300 Subject: [PATCH 07/81] [management] improve zitadel idp error 
response detail by decoding errors (#2634) * [management] improve zitadel idp error response detail by decoding errors * [management] extend readZitadelError to be used for requestJWTToken more generically parse the error returned by zitadel. * fix lint --------- Co-authored-by: bcmmbaga --- management/server/idp/zitadel.go | 49 +++++++++++++++++++++++++-- management/server/idp/zitadel_test.go | 10 +++--- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 729b49733d3..9d7626844a5 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -2,10 +2,12 @@ package idp import ( "context" + "errors" "fmt" "io" "net/http" "net/url" + "slices" "strings" "sync" "time" @@ -97,6 +99,42 @@ type zitadelUserResponse struct { PasswordlessRegistration zitadelPasswordlessRegistration `json:"passwordlessRegistration"` } +// readZitadelError parses errors returned by the zitadel APIs from a response. +func readZitadelError(body io.ReadCloser) error { + bodyBytes, err := io.ReadAll(body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + + helper := JsonParser{} + var target map[string]interface{} + err = helper.Unmarshal(bodyBytes, &target) + if err != nil { + return fmt.Errorf("error unparsable body: %s", string(bodyBytes)) + } + + // ensure keys are ordered for consistent logging behaviour. + errorKeys := make([]string, 0, len(target)) + for k := range target { + errorKeys = append(errorKeys, k) + } + slices.Sort(errorKeys) + + var errsOut []string + for _, k := range errorKeys { + if _, isEmbedded := target[k].(map[string]interface{}); isEmbedded { + continue + } + errsOut = append(errsOut, fmt.Sprintf("%s: %v", k, target[k])) + } + + if len(errsOut) == 0 { + return errors.New("unknown error") + } + + return errors.New(strings.Join(errsOut, " ")) +} + // NewZitadelManager creates a new instance of the ZitadelManager. 
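For illustration only (this snippet is not part of the patch): readZitadelError flattens whatever JSON error body Zitadel returns into a single stable line. A minimal standalone sketch of the same idea, simplified to take a byte slice instead of an io.ReadCloser and using only the standard library:

    package main

    import (
        "encoding/json"
        "errors"
        "fmt"
        "slices"
        "strings"
    )

    // flattenIdpError turns a JSON error body into one sorted "key: value" line.
    func flattenIdpError(body []byte) error {
        var target map[string]interface{}
        if err := json.Unmarshal(body, &target); err != nil {
            return fmt.Errorf("error unparsable body: %s", string(body))
        }

        keys := make([]string, 0, len(target))
        for k := range target {
            keys = append(keys, k)
        }
        slices.Sort(keys) // keep the output stable for logging

        var parts []string
        for _, k := range keys {
            if _, nested := target[k].(map[string]interface{}); nested {
                continue // skip embedded objects, as the patch does
            }
            parts = append(parts, fmt.Sprintf("%s: %v", k, target[k]))
        }
        if len(parts) == 0 {
            return errors.New("unknown error")
        }
        return errors.New(strings.Join(parts, " "))
    }

    func main() {
        body := []byte(`{"error": "invalid_scope", "error_description": "openid missing"}`)
        fmt.Println(flattenIdpError(body))
        // prints: error: invalid_scope error_description: openid missing
    }

With the body {"error": "invalid_scope", "error_description": "openid missing"} this yields "error: invalid_scope error_description: openid missing", which is the detail the updated requestJWTToken, post, and get helpers now append to their status-code errors, as the changed zitadel_test.go expectations below show.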
func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetrics) (*ZitadelManager, error) { httpTransport := http.DefaultTransport.(*http.Transport).Clone() @@ -176,7 +214,8 @@ func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Respon } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode) + zErr := readZitadelError(resp.Body) + return nil, fmt.Errorf("unable to get zitadel token, statusCode %d, zitadel: %w", resp.StatusCode, zErr) } return resp, nil @@ -489,7 +528,9 @@ func (zm *ZitadelManager) post(ctx context.Context, resource string, body string zm.appMetrics.IDPMetrics().CountRequestStatusError() } - return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode) + zErr := readZitadelError(resp.Body) + + return nil, fmt.Errorf("unable to post %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr) } return io.ReadAll(resp.Body) @@ -561,7 +602,9 @@ func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values zm.appMetrics.IDPMetrics().CountRequestStatusError() } - return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode) + zErr := readZitadelError(resp.Body) + + return nil, fmt.Errorf("unable to get %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr) } return io.ReadAll(resp.Body) diff --git a/management/server/idp/zitadel_test.go b/management/server/idp/zitadel_test.go index 6bc612e78c5..722f94fe0b6 100644 --- a/management/server/idp/zitadel_test.go +++ b/management/server/idp/zitadel_test.go @@ -66,7 +66,6 @@ func TestNewZitadelManager(t *testing.T) { } func TestZitadelRequestJWTToken(t *testing.T) { - type requestJWTTokenTest struct { name string inputCode int @@ -88,15 +87,14 @@ func TestZitadelRequestJWTToken(t *testing.T) { requestJWTTokenTestCase2 := requestJWTTokenTest{ name: "Request Bad Status Code", inputCode: 400, - inputRespBody: "{}", + inputRespBody: "{\"error\": \"invalid_scope\", \"error_description\":\"openid missing\"}", helper: JsonParser{}, - expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"), + expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: error: invalid_scope error_description: openid missing"), expectedToken: "", } for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} { t.Run(testCase.name, func(t *testing.T) { - jwtReqClient := mockHTTPClient{ resBody: testCase.inputRespBody, code: testCase.inputCode, @@ -156,7 +154,7 @@ func TestZitadelParseRequestJWTResponse(t *testing.T) { } parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{ name: "Parse Bad json JWT Body", - inputRespBody: "", + inputRespBody: "{}", helper: JsonParser{}, expectedToken: "", expectedExpiresIn: 0, @@ -254,7 +252,7 @@ func TestZitadelAuthenticate(t *testing.T) { inputCode: 400, inputResBody: "{}", helper: JsonParser{}, - expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"), + expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: unknown error"), expectedCode: 200, expectedToken: "", } From 52ae693c9e5eff72082d4330eab6723251559546 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Sun, 29 Sep 2024 00:22:47 +0200 Subject: [PATCH 08/81] [signal] add context to signal-dispatcher (#2662) --- client/cmd/testutil_test.go | 2 +- 
client/internal/engine_test.go | 2 +- client/server/server_test.go | 2 +- go.mod | 2 +- go.sum | 4 ++-- signal/client/client_test.go | 2 +- signal/cmd/run.go | 2 +- signal/server/signal.go | 4 ++-- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 780cc8b04ad..f0dc8bf214c 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -57,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - srv, err := sig.NewServer(otel.Meter("")) + srv, err := sig.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) sigProto.RegisterSignalExchangeServer(s, srv) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index f3056638013..95aadf14186 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1056,7 +1056,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { log.Fatalf("failed to listen: %v", err) } - srv, err := signalServer.NewServer(otel.Meter("")) + srv, err := signalServer.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) proto.RegisterSignalExchangeServer(s, srv) diff --git a/client/server/server_test.go b/client/server/server_test.go index 795060fabc6..9b18df4d37f 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -160,7 +160,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { log.Fatalf("failed to listen: %v", err) } - srv, err := signalServer.NewServer(otel.Meter("")) + srv, err := signalServer.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) proto.RegisterSignalExchangeServer(s, srv) diff --git a/go.mod b/go.mod index 12709e50dcc..cf3b610bd9e 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e - github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 2355f6f0c8c..089629cdf60 100644 --- a/go.sum +++ b/go.sum @@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513- github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080 h1:mXJkoWLdqJTlkQ7DgQ536kcXHXIdUPeagkN8i4eFDdg= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240820130728-bc0683599080/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086 h1:NZm4JvvjKuEh3p7daHUy3rWKhKsnUzzYpGv1qT4dYLc= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed 
h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= diff --git a/signal/client/client_test.go b/signal/client/client_test.go index 2525493b4dd..f7d4ebc5007 100644 --- a/signal/client/client_test.go +++ b/signal/client/client_test.go @@ -199,7 +199,7 @@ func startSignal() (*grpc.Server, net.Listener) { panic(err) } s := grpc.NewServer() - srv, err := server.NewServer(otel.Meter("")) + srv, err := server.NewServer(context.Background(), otel.Meter("")) if err != nil { panic(err) } diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 0bdc62eadeb..1bb2f1d0c14 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -102,7 +102,7 @@ var ( } }() - srv, err := server.NewServer(metricsServer.Meter) + srv, err := server.NewServer(cmd.Context(), metricsServer.Meter) if err != nil { return fmt.Errorf("creating signal server: %v", err) } diff --git a/signal/server/signal.go b/signal/server/signal.go index b268aa3fcbe..c020c56042b 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -47,13 +47,13 @@ type Server struct { } // NewServer creates a new Signal server -func NewServer(meter metric.Meter) (*Server, error) { +func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { appMetrics, err := metrics.NewAppMetrics(meter) if err != nil { return nil, fmt.Errorf("creating app metrics: %v", err) } - dispatcher, err := dispatcher.NewDispatcher() + dispatcher, err := dispatcher.NewDispatcher(ctx) if err != nil { return nil, fmt.Errorf("creating dispatcher: %v", err) } From cfbcf507fb0ae039c270af48822679a754b8c530 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Sun, 29 Sep 2024 20:23:34 +0200 Subject: [PATCH 09/81] propagate meter (#2668) --- go.mod | 2 +- go.sum | 4 ++-- signal/server/signal.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index cf3b610bd9e..edee0ede4db 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e - github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086 + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 089629cdf60..2160fa1f8b7 100644 --- a/go.sum +++ b/go.sum @@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513- github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086 h1:NZm4JvvjKuEh3p7daHUy3rWKhKsnUzzYpGv1qT4dYLc= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240928205912-5569c4c5e086/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= 
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 h1:6XniCzDt+1jvXWMUY4EDH0Hi5RXbUOYB0A8XEQqSlZk= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= diff --git a/signal/server/signal.go b/signal/server/signal.go index c020c56042b..386ce72389f 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -53,7 +53,7 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { return nil, fmt.Errorf("creating app metrics: %v", err) } - dispatcher, err := dispatcher.NewDispatcher(ctx) + dispatcher, err := dispatcher.NewDispatcher(ctx, meter) if err != nil { return nil, fmt.Errorf("creating dispatcher: %v", err) } From 3dca6099d4f1a32c2e2ddbabe88a49d786fb3c41 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 30 Sep 2024 10:34:57 +0200 Subject: [PATCH 10/81] Fix ebpf close function (#2672) --- client/internal/wgproxy/ebpf/proxy.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go index 4bd4bfff624..27ede3ef1d0 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -81,8 +81,7 @@ func (p *WGEBPFProxy) Listen() error { conn, err := nbnet.ListenUDP("udp", &addr) if err != nil { - cErr := p.Free() - if cErr != nil { + if cErr := p.Free(); cErr != nil { log.Errorf("Failed to close the wgproxy: %s", cErr) } return err @@ -122,8 +121,10 @@ func (p *WGEBPFProxy) Free() error { p.ctxCancel() var result *multierror.Error - if err := p.conn.Close(); err != nil { - result = multierror.Append(result, err) + if p.conn != nil { // p.conn will be nil if we have failed to listen + if err := p.conn.Close(); err != nil { + result = multierror.Append(result, err) + } } if err := p.ebpfManager.FreeWGProxy(); err != nil { From 2fd60b2cb46a77f16b5e1e1f72a1a09f03f0ecbe Mon Sep 17 00:00:00 2001 From: Gianluca Boiano <491117+M0Rf30@users.noreply.github.com> Date: Mon, 30 Sep 2024 16:43:34 +0200 Subject: [PATCH 11/81] Specify goreleaser version and update to 2 (#2673) --- .github/workflows/release.yml | 72 +++++++++++++---------------------- .goreleaser.yaml | 32 ++++++++-------- .goreleaser_ui.yaml | 9 +++-- .goreleaser_ui_darwin.yaml | 6 ++- CONTRIBUTING.md | 2 +- 5 files changed, 52 insertions(+), 69 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5f423f1c9ca..162e488c371 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -3,15 +3,14 @@ name: Release on: push: tags: - - 'v*' + - "v*" branches: - main pull_request: - env: SIGN_PIPE_VER: "v0.0.14" - GORELEASER_VER: "v1.14.1" + GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" @@ -34,19 +33,16 @@ jobs: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV - - - name: Checkout + - name: Checkout uses: actions/checkout@v4 with: fetch-depth: 0 # It is required for GoReleaser to work properly - - - name: Set up Go + - name: Set up Go uses: actions/setup-go@v5 with: go-version: "1.23" cache: 
false - - - name: Cache Go modules + - name: Cache Go modules uses: actions/cache@v4 with: path: | @@ -55,20 +51,15 @@ jobs: key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go-releaser- - - - name: Install modules + - name: Install modules run: go mod tidy - - - name: check git status + - name: check git status run: git --no-pager diff --exit-code - - - name: Set up QEMU + - name: Set up QEMU uses: docker/setup-qemu-action@v2 - - - name: Set up Docker Buildx + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 - - - name: Login to Docker hub + - name: Login to Docker hub if: github.event_name != 'pull_request' uses: docker/login-action@v1 with: @@ -85,35 +76,31 @@ jobs: uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --rm-dist ${{ env.flags }} + args: release --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} - - - name: upload non tags for debug purposes + - name: upload non tags for debug purposes uses: actions/upload-artifact@v4 with: name: release path: dist/ retention-days: 3 - - - name: upload linux packages + - name: upload linux packages uses: actions/upload-artifact@v4 with: name: linux-packages path: dist/netbird_linux** retention-days: 3 - - - name: upload windows packages + - name: upload windows packages uses: actions/upload-artifact@v4 with: name: windows-packages path: dist/netbird_windows** retention-days: 3 - - - name: upload macos packages + - name: upload macos packages uses: actions/upload-artifact@v4 with: name: macos-packages @@ -145,7 +132,7 @@ jobs: - name: Cache Go modules uses: actions/cache@v4 with: - path: | + path: | ~/go/pkg/mod ~/.cache/go-build key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }} @@ -169,7 +156,7 @@ jobs: uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }} + args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} @@ -187,19 +174,16 @@ jobs: steps: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV - - - name: Checkout + - name: Checkout uses: actions/checkout@v4 with: fetch-depth: 0 # It is required for GoReleaser to work properly - - - name: Set up Go + - name: Set up Go uses: actions/setup-go@v5 with: go-version: "1.23" cache: false - - - name: Cache Go modules + - name: Cache Go modules uses: actions/cache@v4 with: path: | @@ -208,23 +192,19 @@ jobs: key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-ui-go-releaser-darwin- - - - name: Install modules + - name: Install modules run: go mod tidy - - - name: check git status + - name: check git status run: git --no-pager diff --exit-code - - - name: Run GoReleaser + - name: Run GoReleaser id: goreleaser uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }} + args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: upload non tags for debug purposes + - name: upload non tags for 
debug purposes uses: actions/upload-artifact@v4 with: name: release-ui-darwin @@ -233,7 +213,7 @@ jobs: trigger_signer: runs-on: ubuntu-latest - needs: [release,release_ui,release_ui_darwin] + needs: [release, release_ui, release_ui_darwin] if: startsWith(github.ref, 'refs/tags/') steps: - name: Trigger binaries sign pipelines diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 068864d6ee7..cf2ce4f4f0d 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird builds: - id: netbird @@ -22,7 +24,7 @@ builds: goarch: 386 ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc @@ -42,19 +44,19 @@ builds: - softfloat ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc - id: netbird-mgmt dir: management env: - - CGO_ENABLED=1 - - >- - {{- if eq .Runtime.Goos "linux" }} - {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} - {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} - {{- end }} + - CGO_ENABLED=1 + - >- + {{- if eq .Runtime.Goos "linux" }} + {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} + {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} + {{- end }} binary: netbird-mgmt goos: - linux @@ -64,7 +66,7 @@ builds: - arm ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" - id: netbird-signal dir: signal @@ -78,7 +80,7 @@ builds: - arm ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" - id: netbird-relay dir: relay @@ -92,7 +94,7 @@ builds: - arm ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" archives: - builds: @@ -100,7 +102,6 @@ archives: - netbird-static nfpms: - - maintainer: Netbird description: Netbird client. 
homepage: https://netbird.io/ @@ -416,10 +417,9 @@ docker_manifests: - netbirdio/management:{{ .Version }}-debug-amd64 brews: - - - ids: + - ids: - default - tap: + repository: owner: netbirdio name: homebrew-tap token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}" @@ -436,7 +436,7 @@ brews: uploads: - name: debian ids: - - netbird-deb + - netbird-deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml index fd92b5328d8..06577f4e3c3 100644 --- a/.goreleaser_ui.yaml +++ b/.goreleaser_ui.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird-ui builds: - id: netbird-ui @@ -11,7 +13,7 @@ builds: - amd64 ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" - id: netbird-ui-windows dir: client/ui @@ -26,7 +28,7 @@ builds: ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -H windowsgui - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" archives: - id: linux-arch @@ -39,7 +41,6 @@ archives: - netbird-ui-windows nfpms: - - maintainer: Netbird description: Netbird client UI. homepage: https://netbird.io/ @@ -77,7 +78,7 @@ nfpms: uploads: - name: debian ids: - - netbird-ui-deb + - netbird-ui-deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com diff --git a/.goreleaser_ui_darwin.yaml b/.goreleaser_ui_darwin.yaml index 2c3afa91bb6..bccb7f4717a 100644 --- a/.goreleaser_ui_darwin.yaml +++ b/.goreleaser_ui_darwin.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird-ui builds: - id: netbird-ui-darwin @@ -17,7 +19,7 @@ builds: - softfloat ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc @@ -28,4 +30,4 @@ archives: checksum: name_template: "{{ .ProjectName }}_darwin_checksums.txt" changelog: - skip: true \ No newline at end of file + disable: true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 492aa5c2ed8..c82cfc763f7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR: **Goreleaser** ```shell -goreleaser --snapshot --rm-dist +goreleaser build --snapshot --clean ``` **golangci-lint** ```shell From e27f85b317a97721921933659a80c8be35c785e1 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 30 Sep 2024 20:07:21 +0200 Subject: [PATCH 12/81] Update docker creds (#2677) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 162e488c371..7af6d3e4d94 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -63,7 +63,7 @@ jobs: if: github.event_name != 'pull_request' uses: docker/login-action@v1 with: - username: 
netbirdio + username: ${{ secrets.DOCKER_USER }} password: ${{ secrets.DOCKER_TOKEN }} - name: Install OS build dependencies run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu From 16179db599ef6fb42e709597bc260101dfa7cd74 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 30 Sep 2024 22:18:10 +0200 Subject: [PATCH 13/81] [management] Propagate metrics (#2667) --- go.mod | 2 +- go.sum | 4 ++-- management/server/http/handler.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index edee0ede4db..c29ba076347 100644 --- a/go.mod +++ b/go.mod @@ -59,7 +59,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e + github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 2160fa1f8b7..1f6cbb785be 100644 --- a/go.sum +++ b/go.sum @@ -521,8 +521,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 h1:6XniCzDt+1jvXWMUY4EDH0Hi5RXbUOYB0A8XEQqSlZk= diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ef94f22b9e1..3f8a8554d07 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -82,7 +82,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa AuthCfg: authCfg, } - if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil { + if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { return nil, fmt.Errorf("register integrations endpoints: %w", err) } From 24c0aaa745bc2ac46bdcf1f855834306a886db95 Mon Sep 17 00:00:00 2001 From: Simen <97337442+simen64@users.noreply.github.com> Date: Tue, 1 Oct 2024 13:32:58 +0200 Subject: [PATCH 14/81] Install sh 
alpine fixes (#2678) * Made changes to the peer install script so that it works on Alpine Linux without further changes * fix a small oversight with the doas fix * use a try/catch approach when curling binaries --- release_files/install.sh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/release_files/install.sh b/release_files/install.sh index d6aabebd8b2..5dd0f67bb8e 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -21,6 +21,8 @@ SUDO="" if command -v sudo > /dev/null && [ "$(id -u)" -ne 0 ]; then SUDO="sudo" +elif command -v doas > /dev/null && [ "$(id -u)" -ne 0 ]; then + SUDO="doas" fi if [ -z ${NETBIRD_RELEASE+x} ]; then @@ -68,7 +70,7 @@ download_release_binary() { if [ -n "$GITHUB_TOKEN" ]; then cd /tmp && curl -H "Authorization: token ${GITHUB_TOKEN}" -LO "$DOWNLOAD_URL" else - cd /tmp && curl -LO "$DOWNLOAD_URL" + cd /tmp && curl -LO "$DOWNLOAD_URL" || curl -LO --dns-servers 8.8.8.8 "$DOWNLOAD_URL" fi @@ -316,7 +318,7 @@ install_netbird() { } version_greater_equal() { - printf '%s\n%s\n' "$2" "$1" | sort -V -C + printf '%s\n%s\n' "$2" "$1" | sort -V -c } is_bin_package_manager() {
From ee0ea86a0a9394b2632ed2be3149d45c04baca67 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 1 Oct 2024 16:22:18 +0200 Subject: [PATCH 15/81] [relay-client] Fix Relay disconnection handling (#2680) * Fix Relay disconnection handling: if there was an active P2P connection while the Relay connection to the server broke, we removed the WireGuard peer configuration. * Change logs --- client/internal/peer/conn.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ea6d892b9f6..baff1372a16 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -586,13 +586,17 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { return } - if conn.wgProxyRelay != nil { - log.Debugf("relayed connection is closed, clean up WireGuard config") + log.Debugf("relay connection is disconnected") + + if conn.currentConnPriority == connPriorityRelay { + log.Debugf("clean up WireGuard config") err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) if err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } + } + if conn.wgProxyRelay != nil { conn.endpointRelay = nil _ = conn.wgProxyRelay.CloseConn() conn.wgProxyRelay = nil
From 5932298ce03ccda417cbf954020665fdc096baaa Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 2 Oct 2024 11:48:09 +0200 Subject: [PATCH 16/81] Add log setting to Caddy container (#2684) This avoids filling up the disk on busy systems --- infrastructure_files/getting-started-with-zitadel.sh | 5 +++++ 1 file changed, 5 insertions(+)
diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index c0275536baa..2c5c35d5302 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -793,6 +793,11 @@ services: volumes: - netbird_caddy_data:/data - ./Caddyfile:/etc/caddy/Caddyfile + logging: + driver: "json-file" + options: + max-size: "500m" + max-file: "2" # UI dashboard dashboard: image: netbirdio/dashboard:latest
From a3a479429eb13dc53b9d9dd7bfb1b0710c5055c0 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 2 Oct 2024 11:48:42 +0200 Subject: [PATCH 17/81] Use the pkgs to get the latest version (#2682) * Use the pkgs to get the latest version * disable fail fast ---
.github/workflows/install-script-test.yml | 1 + release_files/install.sh | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/install-script-test.yml b/.github/workflows/install-script-test.yml index 04c222e873f..22d002a4833 100644 --- a/.github/workflows/install-script-test.yml +++ b/.github/workflows/install-script-test.yml @@ -13,6 +13,7 @@ concurrency: jobs: test-install-script: strategy: + fail-fast: false max-parallel: 2 matrix: os: [ubuntu-latest, macos-latest] diff --git a/release_files/install.sh b/release_files/install.sh index 5dd0f67bb8e..b7a6c08f9a7 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -33,14 +33,16 @@ get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then local TAG="latest" + local URL="https://pkgs.netbird.io/releases/latest" else local TAG="tags/${RELEASE}" + local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \ + curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' else - curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \ + curl -s "${URL}" \ | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' fi } From ff7863785f81c64ce0570b28950f806b75800c6a Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 2 Oct 2024 14:41:00 +0300 Subject: [PATCH 18/81] [management, client] Add access control support to network routes (#2100) --- .github/workflows/golangci-lint.yml | 2 +- client/firewall/iface.go | 4 +- client/firewall/iptables/acl_linux.go | 174 +-- client/firewall/iptables/manager_linux.go | 64 +- .../firewall/iptables/manager_linux_test.go | 54 +- client/firewall/iptables/router_linux.go | 536 +++++---- client/firewall/iptables/router_linux_test.go | 268 +++-- client/firewall/manager/firewall.go | 125 +- client/firewall/manager/firewall_test.go | 192 ++++ client/firewall/manager/routerpair.go | 16 +- client/firewall/nftables/acl_linux.go | 549 +-------- client/firewall/nftables/manager_linux.go | 121 +- .../firewall/nftables/manager_linux_test.go | 76 +- client/firewall/nftables/route_linux.go | 431 ------- client/firewall/nftables/router_linux.go | 798 +++++++++++++ client/firewall/nftables/router_linux_test.go | 605 ++++++++-- client/firewall/test/cases_linux.go | 20 +- client/firewall/uspfilter/uspfilter.go | 42 +- client/firewall/uspfilter/uspfilter_test.go | 20 +- client/internal/acl/id/id.go | 25 + client/internal/acl/manager.go | 255 +++-- client/internal/acl/manager_test.go | 170 +-- client/internal/engine.go | 9 +- client/internal/routemanager/dynamic/route.go | 2 +- client/internal/routemanager/manager.go | 6 +- .../routemanager/refcounter/refcounter.go | 199 +++- .../internal/routemanager/refcounter/types.go | 6 +- .../routemanager/server_nonandroid.go | 16 +- client/internal/routemanager/static/route.go | 2 +- .../routemanager/systemops/systemops.go | 2 +- .../systemops/systemops_generic.go | 4 +- management/proto/management.pb.go | 1005 +++++++++++------ management/proto/management.proto | 84 +- management/server/account.go | 4 +- management/server/account_test.go | 7 +- management/server/grpcserver.go | 4 + management/server/http/api/openapi.yml | 30 +- management/server/http/api/types.gen.go | 30 +- management/server/http/policies_handler.go | 33 +- management/server/http/routes_handler.go | 16 +- 
management/server/http/routes_handler_test.go | 48 +- management/server/mock_server/account_mock.go | 8 +- management/server/network.go | 13 +- management/server/peer_test.go | 7 +- management/server/policy.go | 48 +- management/server/route.go | 292 ++++- management/server/route_test.go | 536 +++++++-- route/route.go | 5 +- 48 files changed, 4601 insertions(+), 2362 deletions(-) create mode 100644 client/firewall/manager/firewall_test.go delete mode 100644 client/firewall/nftables/route_linux.go create mode 100644 client/firewall/nftables/router_linux.go create mode 100644 client/internal/acl/id/id.go diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 8b713684130..2d743f79071 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable, + ignore_words_list: erro,clienta,hastable,iif skip: go.mod,go.sum only_warn: 1 golangci: diff --git a/client/firewall/iface.go b/client/firewall/iface.go index 882daef7514..d0b5209c040 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -1,6 +1,8 @@ package firewall -import "github.com/netbirdio/netbird/iface" +import ( + "github.com/netbirdio/netbird/iface" +) // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index b77cc8f4346..c6a96a876cd 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -19,24 +19,22 @@ const ( // rules chains contains the effective ACL rules chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameOutputRules = "NETBIRD-ACL-OUTPUT" - - postRoutingMark = "0x000007e4" ) type aclManager struct { - iptablesClient *iptables.IPTables - wgIface iFaceMapper - routeingFwChainName string + iptablesClient *iptables.IPTables + wgIface iFaceMapper + routingFwChainName string entries map[string][][]string ipsetStore *ipsetStore } -func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) { +func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { m := &aclManager{ - iptablesClient: iptablesClient, - wgIface: wgIface, - routeingFwChainName: routeingFwChainName, + iptablesClient: iptablesClient, + wgIface: wgIface, + routingFwChainName: routingFwChainName, entries: make(map[string][][]string), ipsetStore: newIpsetStore(), @@ -61,7 +59,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, route return m, nil } -func (m *aclManager) AddFiltering( +func (m *aclManager) AddPeerFiltering( ip net.IP, protocol firewall.Protocol, sPort *firewall.Port, @@ -127,7 +125,7 @@ func (m *aclManager) AddFiltering( return nil, fmt.Errorf("rule already exists") } - if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil { + if err := m.iptablesClient.Append("filter", chain, specs...); err != nil { return nil, err } @@ -139,28 +137,16 @@ func (m *aclManager) AddFiltering( chain: chain, } - if !shouldAddToPrerouting(protocol, dPort, direction) { - return []firewall.Rule{rule}, nil - } - - rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip) - if err != nil { - return []firewall.Rule{rule}, err - } - return []firewall.Rule{rule, rulePrerouting}, nil + 
return []firewall.Rule{rule}, nil } -// DeleteRule from the firewall by rule definition -func (m *aclManager) DeleteRule(rule firewall.Rule) error { +// DeletePeerRule from the firewall by rule definition +func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { r, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") } - if r.chain == "PREROUTING" { - goto DELETERULE - } - if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { // delete IP from ruleset IPs list and ipset if _, ok := ipsetList.ips[r.ip]; ok { @@ -185,14 +171,7 @@ func (m *aclManager) DeleteRule(rule firewall.Rule) error { } } -DELETERULE: - var table string - if r.chain == "PREROUTING" { - table = "mangle" - } else { - table = "filter" - } - err := m.iptablesClient.Delete(table, r.chain, r.specs...) + err := m.iptablesClient.Delete(tableName, r.chain, r.specs...) if err != nil { log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err) } @@ -203,44 +182,6 @@ func (m *aclManager) Reset() error { return m.cleanChains() } -func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) { - var src []string - if ipsetName != "" { - src = []string{"-m", "set", "--set", ipsetName, "src"} - } else { - src = []string{"-s", ip.String()} - } - specs := []string{ - "-d", m.wgIface.Address().IP.String(), - "-p", protocol, - "--dport", port, - "-j", "MARK", "--set-mark", postRoutingMark, - } - - specs = append(src, specs...) - - ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...) - if err != nil { - return nil, fmt.Errorf("failed to check rule: %w", err) - } - if ok { - return nil, fmt.Errorf("rule already exists") - } - - if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil { - return nil, err - } - - rule := &Rule{ - ruleID: uuid.New().String(), - specs: specs, - ipsetName: ipsetName, - ip: ip.String(), - chain: "PREROUTING", - } - return rule, nil -} - // todo write less destructive cleanup mechanism func (m *aclManager) cleanChains() error { ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules) @@ -291,25 +232,6 @@ func (m *aclManager) cleanChains() error { } } - ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") - if err != nil { - log.Debugf("failed to list chains: %s", err) - return err - } - if ok { - for _, rule := range m.entries["PREROUTING"] { - err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...) 
- if err != nil { - log.Errorf("failed to delete rule: %v, %s", rule, err) - } - } - err = m.iptablesClient.ClearChain("mangle", "PREROUTING") - if err != nil { - log.Debugf("failed to clear %s chain: %s", "PREROUTING", err) - return err - } - } - for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := ipset.Flush(ipsetName); err != nil { log.Errorf("flush ipset %q during reset: %v", ipsetName, err) @@ -338,17 +260,9 @@ func (m *aclManager) createDefaultChains() error { for chainName, rules := range m.entries { for _, rule := range rules { - if chainName == "FORWARD" { - // position 2 because we add it after router's, jump rule - if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil { - log.Debugf("failed to create input chain jump rule: %s", err) - return err - } - } else { - if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil { - log.Debugf("failed to create input chain jump rule: %s", err) - return err - } + if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { + log.Debugf("failed to create input chain jump rule: %s", err) + return err } } } @@ -356,40 +270,29 @@ func (m *aclManager) createDefaultChains() error { return nil } -func (m *aclManager) seedInitialEntries() { - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) +// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed. +// We want to make sure our traffic is not dropped by existing rules. - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules}) +// The existing FORWARD rules/policies decide outbound traffic towards our interface. +// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. - m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) - - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) +// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule. 
+func (m *aclManager) seedInitialEntries() { - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) + established := getConntrackEstablished() - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules}) + m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) + m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) + m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) + m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules}) + m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) + m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...)) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) - m.appendToEntries("FORWARD", - []string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"}) - m.appendToEntries("FORWARD", - []string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"}) - m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName}) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName}) - - m.appendToEntries("PREROUTING", - []string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark}) + m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) + m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) } func (m *aclManager) appendToEntries(chainName string, spec []string) { @@ -456,18 +359,3 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string { return ipsetName } } - -func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool { - if proto == "all" { - return false - } - - if direction != firewall.RuleDirectionIN { - return false - } - - if dPort == nil { - return false - } - return true -} diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 2d231ec456d..fae41d9c5a9 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "net/netip" "sync" "github.com/coreos/go-iptables/iptables" @@ -21,7 +22,7 @@ type Manager struct { ipv4Client *iptables.IPTables aclMgr *aclManager - router *routerManager + router *router } // iFaceMapper defines subset methods of interface required for manager @@ -43,12 +44,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouterManager(context, iptablesClient) + m.router, err = newRouter(context, iptablesClient, wgIface) if err != nil { log.Debugf("failed to initialize route related chains: %s", err) return nil, err } - m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName()) + m.aclMgr, err = newAclManager(iptablesClient, wgIface, 
chainRTFWD) if err != nil { log.Debugf("failed to initialize ACL manager: %s", err) return nil, err @@ -57,10 +58,10 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, protocol firewall.Protocol, sPort *firewall.Port, @@ -73,33 +74,62 @@ func (m *Manager) AddFiltering( m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) + return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) } -// DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule firewall.Rule) error { +func (m *Manager) AddRouteFiltering( + sources [] netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if !destination.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + } + + return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) +} + +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.DeleteRule(rule) + return m.aclMgr.DeletePeerRule(rule) +} + +func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.DeleteRouteRule(rule) } func (m *Manager) IsServerRouteSupported() bool { return true } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.InsertRoutingRules(pair) + return m.router.AddNatRule(pair) } -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveRoutingRules(pair) + return m.router.RemoveNatRule(pair) +} + +func (m *Manager) SetLegacyManagement(isLegacy bool) error { + return firewall.SetLegacyManagement(m.router, isLegacy) } // Reset firewall to the default state @@ -125,7 +155,7 @@ func (m *Manager) AllowNetbird() error { return nil } - _, err := m.AddFiltering( + _, err := m.AddPeerFiltering( net.ParseIP("0.0.0.0"), "all", nil, @@ -138,7 +168,7 @@ func (m *Manager) AllowNetbird() error { if err != nil { return fmt.Errorf("failed to allow netbird interface traffic: %w", err) } - _, err = m.AddFiltering( + _, err = m.AddPeerFiltering( net.ParseIP("0.0.0.0"), "all", nil, @@ -153,3 +183,7 @@ func (m *Manager) AllowNetbird() error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } + +func getConntrackEstablished() []string { + return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} +} diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index ceb116c6225..0072aa15961 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -14,6 +14,21 @@ import ( "github.com/netbirdio/netbird/iface" ) +var ifaceMock 
= &iFaceMock{ + NameFunc: func() string { + return "lo" + }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("10.20.0.1"), + Network: &net.IPNet{ + IP: net.ParseIP("10.20.0.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + } + }, +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMock struct { NameFunc func() string @@ -40,23 +55,8 @@ func TestIptablesManager(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) - mock := &iFaceMock{ - NameFunc: func() string { - return "lo" - }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ - IP: net.ParseIP("10.20.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("10.20.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - } - }, - } - // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(context.Background(), ifaceMock) require.NoError(t, err) time.Sleep(time.Second) @@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) { t.Run("add first rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.2") port := &fw.Port{Values: []int{8080}} - rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") for _, r := range rule1 { @@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) { port := &fw.Port{ Values: []int{8043: 8046}, } - rule2, err = manager.AddFiltering( + rule2, err = manager.AddPeerFiltering( ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") require.NoError(t, err, "failed to add rule") @@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) { t.Run("delete first rule", func(t *testing.T) { for _, r := range rule1 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...) 
@@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) { t.Run("delete second rule", func(t *testing.T) { for _, r := range rule2 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") } @@ -119,7 +119,7 @@ func TestIptablesManager(t *testing.T) { // add second rule ip := net.ParseIP("10.20.0.3") port := &fw.Port{Values: []int{5353}} - _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") + _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") require.NoError(t, err, "failed to add rule") err = manager.Reset() @@ -170,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("add first rule with set", func(t *testing.T) { ip := net.ParseIP("10.20.0.2") port := &fw.Port{Values: []int{8080}} - rule1, err = manager.AddFiltering( + rule1, err = manager.AddPeerFiltering( ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "default", "accept HTTP traffic", ) @@ -189,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) { port := &fw.Port{ Values: []int{443}, } - rule2, err = manager.AddFiltering( + rule2, err = manager.AddPeerFiltering( ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "default", "accept HTTPS traffic from ports range", ) @@ -202,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("delete first rule", func(t *testing.T) { for _, r := range rule1 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index") @@ -211,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("delete second rule", func(t *testing.T) { for _, r := range rule2 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty") @@ -269,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index e8f09a106c9..737b207854b 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -5,368 +5,478 @@ package iptables import ( "context" "fmt" + "net/netip" + "strconv" "strings" "github.com/coreos/go-iptables/iptables" + "github.com/hashicorp/go-multierror" + "github.com/nadoo/ipset" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) const ( - 
Ipv4Forwarding = "netbird-rt-forwarding" - ipv4Nat = "netbird-rt-nat" + ipv4Nat = "netbird-rt-nat" ) // constants needed to manage and create iptable rules const ( tableFilter = "filter" tableNat = "nat" - chainFORWARD = "FORWARD" chainPOSTROUTING = "POSTROUTING" chainRTNAT = "NETBIRD-RT-NAT" chainRTFWD = "NETBIRD-RT-FWD" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" + + matchSet = "--match-set" ) -type routerManager struct { - ctx context.Context - stop context.CancelFunc - iptablesClient *iptables.IPTables - rules map[string][]string +type routeFilteringRuleParams struct { + Sources []netip.Prefix + Destination netip.Prefix + Proto firewall.Protocol + SPort *firewall.Port + DPort *firewall.Port + Direction firewall.RuleDirection + Action firewall.Action + SetName string +} + +type router struct { + ctx context.Context + stop context.CancelFunc + iptablesClient *iptables.IPTables + rules map[string][]string + ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}] + wgIface iFaceMapper + legacyManagement bool } -func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) { +func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { ctx, cancel := context.WithCancel(parentCtx) - m := &routerManager{ + r := &router{ ctx: ctx, stop: cancel, iptablesClient: iptablesClient, rules: make(map[string][]string), + wgIface: wgIface, + } + + r.ipsetCounter = refcounter.New( + r.createIpSet, + func(name string, _ struct{}) error { + return r.deleteIpSet(name) + }, + ) + + if err := ipset.Init(); err != nil { + return nil, fmt.Errorf("init ipset: %w", err) } - err := m.cleanUpDefaultForwardRules() + err := r.cleanUpDefaultForwardRules() if err != nil { - log.Errorf("failed to cleanup routing rules: %s", err) + log.Errorf("cleanup routing rules: %s", err) return nil, err } - err = m.createContainers() + err = r.createContainers() if err != nil { - log.Errorf("failed to create containers for route: %s", err) + log.Errorf("create containers for route: %s", err) } - return m, err + return r, err } -// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain -func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error { - err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair) - if err != nil { - return err +func (r *router) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) + if _, ok := r.rules[string(ruleKey)]; ok { + return ruleKey, nil + } + + var setName string + if len(sources) > 1 { + setName = firewall.GenerateSetName(sources) + if _, err := r.ipsetCounter.Increment(setName, sources); err != nil { + return nil, fmt.Errorf("create or get ipset: %w", err) + } } - err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair)) - if err != nil { - return err + params := routeFilteringRuleParams{ + Sources: sources, + Destination: destination, + Proto: proto, + SPort: sPort, + DPort: dPort, + Action: action, + SetName: setName, } - if !pair.Masquerade { - return nil + rule := genRouteFilteringRuleSpec(params) + if err := 
r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + return nil, fmt.Errorf("add route rule: %v", err) } - err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) - if err != nil { - return err - } + r.rules[string(ruleKey)] = rule - err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) - if err != nil { - return err + return ruleKey, nil +} + +func (r *router) DeleteRouteRule(rule firewall.Rule) error { + ruleKey := rule.GetRuleID() + + if rule, exists := r.rules[ruleKey]; exists { + setName := r.findSetNameInRule(rule) + + if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("delete route rule: %v", err) + } + delete(r.rules, ruleKey) + + if setName != "" { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + return fmt.Errorf("failed to remove ipset: %w", err) + } + } + } else { + log.Debugf("route rule %s not found", ruleKey) } return nil } -// insertRoutingRule inserts an iptables rule -func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { - var err error - - ruleKey := firewall.GenKey(keyFormat, pair.ID) - rule := genRuleSpec(jump, pair.Source, pair.Destination) - existingRule, found := i.rules[ruleKey] - if found { - err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) - if err != nil { - return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) +func (r *router) findSetNameInRule(rule []string) string { + for i, arg := range rule { + if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { + return rule[i+3] } - delete(i.rules, ruleKey) } + return "" +} - err = i.iptablesClient.Insert(table, chain, 1, rule...) - if err != nil { - return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) +func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) { + if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { + return struct{}{}, fmt.Errorf("create set %s: %w", setName, err) } - i.rules[ruleKey] = rule + for _, prefix := range sources { + if err := ipset.AddPrefix(setName, prefix); err != nil { + return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err) + } + } - return nil + return struct{}{}, nil } -// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains -func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error { - err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair) - if err != nil { - return err +func (r *router) deleteIpSet(setName string) error { + if err := ipset.Destroy(setName); err != nil { + return fmt.Errorf("destroy set %s: %w", setName, err) } + return nil +} - err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair)) - if err != nil { - return err +// AddNatRule inserts an iptables rule pair into the nat chain +func (r *router) AddNatRule(pair firewall.RouterPair) error { + if r.legacyManagement { + log.Warnf("This peer is connected to a NetBird Management service with an older version. 
Allowing all traffic for %s", pair.Destination) + if err := r.addLegacyRouteRule(pair); err != nil { + return fmt.Errorf("add legacy routing rule: %w", err) + } } if !pair.Masquerade { return nil } - err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair) - if err != nil { - return err + if err := r.addNatRule(pair); err != nil { + return fmt.Errorf("add nat rule: %w", err) } - err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair)) - if err != nil { + if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("add inverse nat rule: %w", err) + } + + return nil +} + +// RemoveNatRule removes an iptables rule pair from forwarding and nat chains +func (r *router) RemoveNatRule(pair firewall.RouterPair) error { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove nat rule: %w", err) + } + + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse nat rule: %w", err) + } + + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + + return nil +} + +// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls +func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if err := r.removeLegacyRouteRule(pair); err != nil { return err } + rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} + if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + r.rules[ruleKey] = rule + return nil } -func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error { - var err error +func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) - ruleKey := firewall.GenKey(keyFormat, pair.ID) - existingRule, found := i.rules[ruleKey] - if found { - err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) 
- if err != nil { - return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } + delete(r.rules, ruleKey) + } else { + log.Debugf("legacy forwarding rule %s not found", ruleKey) } - delete(i.rules, ruleKey) return nil } -func (i *routerManager) RouteingFwChainName() string { - return chainRTFWD +// GetLegacyManagement returns the current legacy management mode +func (r *router) GetLegacyManagement() bool { + return r.legacyManagement } -func (i *routerManager) Reset() error { - err := i.cleanUpDefaultForwardRules() - if err != nil { - return err +// SetLegacyManagement sets the route manager to use legacy management mode +func (r *router) SetLegacyManagement(isLegacy bool) { + r.legacyManagement = isLegacy +} + +// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls +func (r *router) RemoveAllLegacyRouteRules() error { + var merr *multierror.Error + for k, rule := range r.rules { + if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { + continue + } + if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) + } } - i.rules = make(map[string][]string) - return nil + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) Reset() error { + var merr *multierror.Error + if err := r.cleanUpDefaultForwardRules(); err != nil { + merr = multierror.Append(merr, err) + } + r.rules = make(map[string][]string) + + if err := r.ipsetCounter.Flush(); err != nil { + merr = multierror.Append(merr, err) + } + + return nberrors.FormatErrorOrNil(merr) } -func (i *routerManager) cleanUpDefaultForwardRules() error { - err := i.cleanJumpRules() +func (r *router) cleanUpDefaultForwardRules() error { + err := r.cleanJumpRules() if err != nil { return err } log.Debug("flushing routing related tables") - ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD) - if err != nil { - log.Errorf("failed check chain %s,error: %v", chainRTFWD, err) - return err - } else if ok { - err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD) - if err != nil { - log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err) - return err + for _, chain := range []string{chainRTFWD, chainRTNAT} { + table := tableFilter + if chain == chainRTNAT { + table = tableNat } - } - ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT) - if err != nil { - log.Errorf("failed check chain %s,error: %v", chainRTNAT, err) - return err - } else if ok { - err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT) + ok, err := r.iptablesClient.ChainExists(table, chain) if err != nil { - log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err) + log.Errorf("failed check chain %s, error: %v", chain, err) return err + } else if ok { + err = r.iptablesClient.ClearAndDeleteChain(table, chain) + if err != nil { + log.Errorf("failed cleaning chain %s, error: %v", chain, err) + return err + } } } + return nil } -func (i *routerManager) createContainers() error { - if i.rules[Ipv4Forwarding] != nil { - return nil +func (r *router) createContainers() error { + for _, chain := range []string{chainRTFWD, chainRTNAT} { + if err := 
r.createAndSetupChain(chain); err != nil { + return fmt.Errorf("create chain %s: %v", chain, err) + } } - errMSGFormat := "failed creating chain %s,error: %v" - err := i.createChain(tableFilter, chainRTFWD) - if err != nil { - return fmt.Errorf(errMSGFormat, chainRTFWD, err) + if err := r.insertEstablishedRule(chainRTFWD); err != nil { + return fmt.Errorf("insert established rule: %v", err) } - err = i.createChain(tableNat, chainRTNAT) - if err != nil { - return fmt.Errorf(errMSGFormat, chainRTNAT, err) - } + return r.addJumpRules() +} - err = i.addJumpRules() - if err != nil { - return fmt.Errorf("error while creating jump rules: %v", err) +func (r *router) createAndSetupChain(chain string) error { + table := r.getTableForChain(chain) + + if err := r.iptablesClient.NewChain(table, chain); err != nil { + return fmt.Errorf("failed creating chain %s, error: %v", chain, err) } return nil } -// addJumpRules create jump rules to send packets to NetBird chains -func (i *routerManager) addJumpRules() error { - rule := []string{"-j", chainRTFWD} - err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...) +func (r *router) getTableForChain(chain string) string { + if chain == chainRTNAT { + return tableNat + } + return tableFilter +} + +func (r *router) insertEstablishedRule(chain string) error { + establishedRule := getConntrackEstablished() + + err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...) if err != nil { - return err + return fmt.Errorf("failed to insert established rule: %v", err) } - i.rules[Ipv4Forwarding] = rule - rule = []string{"-j", chainRTNAT} - err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) + ruleKey := "established-" + chain + r.rules[ruleKey] = establishedRule + + return nil +} + +func (r *router) addJumpRules() error { + rule := []string{"-j", chainRTNAT} + err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) if err != nil { return err } - i.rules[ipv4Nat] = rule + r.rules[ipv4Nat] = rule return nil } -// cleanJumpRules cleans jump rules that was sending packets to NetBird chains -func (i *routerManager) cleanJumpRules() error { - var err error - errMSGFormat := "failed cleaning rule from chain %s,err: %v" - rule, found := i.rules[Ipv4Forwarding] - if found { - err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...) - if err != nil { - return fmt.Errorf(errMSGFormat, chainFORWARD, err) - } - } - rule, found = i.rules[ipv4Nat] +func (r *router) cleanJumpRules() error { + rule, found := r.rules[ipv4Nat] if found { - err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) + err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) if err != nil { - return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err) + return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err) } } - rules, err := i.iptablesClient.List("nat", "POSTROUTING") - if err != nil { - return fmt.Errorf("failed to list rules: %s", err) - } + return nil +} - for _, ruleString := range rules { - if !strings.Contains(ruleString, "NETBIRD") { - continue - } - rule := strings.Fields(ruleString) - err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...) 
- if err != nil { - return fmt.Errorf("failed to delete postrouting jump rule: %s", err) +func (r *router) addNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { + return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) } + delete(r.rules, ruleKey) } - rules, err = i.iptablesClient.List(tableFilter, "FORWARD") - if err != nil { - return fmt.Errorf("failed to list rules in FORWARD chain: %s", err) + rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse) + if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil { + return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) } - for _, ruleString := range rules { - if !strings.Contains(ruleString, "NETBIRD") { - continue - } - rule := strings.Fields(ruleString) - err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...) - if err != nil { - return fmt.Errorf("failed to delete FORWARD jump rule: %s", err) - } - } + r.rules[ruleKey] = rule + return nil } -func (i *routerManager) createChain(table, newChain string) error { - chains, err := i.iptablesClient.ListChains(table) - if err != nil { - return fmt.Errorf("couldn't get %s table chains, error: %v", table, err) - } +func (r *router) removeNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) - shouldCreateChain := true - for _, chain := range chains { - if chain == newChain { - shouldCreateChain = false + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { + return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err) } - } - if shouldCreateChain { - err = i.iptablesClient.NewChain(table, newChain) - if err != nil { - return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err) - } - - // Add the loopback return rule to the NAT chain - loopbackRule := []string{"-o", "lo", "-j", "RETURN"} - err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...) - if err != nil { - return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err) - } + delete(r.rules, ruleKey) + } else { + log.Debugf("nat rule %s not found", ruleKey) + } - err = i.iptablesClient.Append(table, newChain, "-j", "RETURN") - if err != nil { - return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err) - } + return nil +} +func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { + intdir := "-i" + if inverse { + intdir = "-o" } - return nil + return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump} } -// addNATRule appends an iptables rule pair to the nat chain -func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(keyFormat, pair.ID) - rule := genRuleSpec(jump, pair.Source, pair.Destination) - existingRule, found := i.rules[ruleKey] - if found { - err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...) 
- if err != nil { - return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) - } - delete(i.rules, ruleKey) +func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { + var rule []string + + if params.SetName != "" { + rule = append(rule, "-m", "set", matchSet, params.SetName, "src") + } else if len(params.Sources) > 0 { + source := params.Sources[0] + rule = append(rule, "-s", source.String()) } - // inserting after loopback ignore rule - err := i.iptablesClient.Insert(table, chain, 2, rule...) - if err != nil { - return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) + rule = append(rule, "-d", params.Destination.String()) + + if params.Proto != firewall.ProtocolALL { + rule = append(rule, "-p", strings.ToLower(string(params.Proto))) + rule = append(rule, applyPort("--sport", params.SPort)...) + rule = append(rule, applyPort("--dport", params.DPort)...) } - i.rules[ruleKey] = rule + rule = append(rule, "-j", actionToStr(params.Action)) - return nil + return rule } -// genRuleSpec generates rule specification -func genRuleSpec(jump, source, destination string) []string { - return []string{"-s", source, "-d", destination, "-j", jump} -} +func applyPort(flag string, port *firewall.Port) []string { + if port == nil { + return nil + } -func getIptablesRuleType(table string) string { - ruleType := "forwarding" - if table == tableNat { - ruleType = "nat" + if port.IsRange && len(port.Values) == 2 { + return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])} } - return ruleType + + if len(port.Values) > 1 { + portList := make([]string, len(port.Values)) + for i, p := range port.Values { + portList[i] = strconv.Itoa(p) + } + return []string{"-m", "multiport", flag, strings.Join(portList, ",")} + } + + return []string{flag, strconv.Itoa(port.Values[0])} } diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 79b970c36af..6cede09e2b9 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -4,11 +4,13 @@ package iptables import ( "context" + "net/netip" "os/exec" "testing" "github.com/coreos/go-iptables/iptables" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -28,7 +30,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "should return a valid iptables manager") defer func() { @@ -37,26 +39,22 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.Len(t, manager.rules, 2, "should have created rules map") - exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD) - require.True(t, exists, "forwarding rule should exist") - - exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) + exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) 
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.True(t, exists, "postrouting rule should exist") pair := firewall.RouterPair{ ID: "abc", - Source: "100.100.100.1/32", - Destination: "100.100.100.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.100.0/24"), Masquerade: true, } - forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination) + forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) require.NoError(t, err, "inserting rule should not return error") - nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination) + nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false) err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...) require.NoError(t, err, "inserting rule should not return error") @@ -65,7 +63,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.NoError(t, err, "shouldn't return error") } -func TestIptablesManager_InsertRoutingRules(t *testing.T) { +func TestIptablesManager_AddNatRule(t *testing.T) { if !isIptablesSupported() { t.SkipNow() @@ -76,7 +74,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") defer func() { @@ -86,35 +84,13 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { } }() - err = manager.InsertRoutingRules(testCase.InputPair) + err = manager.AddNatRule(testCase.InputPair) require.NoError(t, err, "forwarding pair should be inserted") - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination) - - exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.True(t, exists, "forwarding rule should exist") - - foundRule, found := manager.rules[forwardRuleKey] - require.True(t, found, "forwarding rule should exist in the manager map") - require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match") - - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) - - exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...) 
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.True(t, exists, "income forwarding rule should exist") + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) - foundRule, found = manager.rules[inForwardRuleKey] - require.True(t, found, "income forwarding rule should exist in the manager map") - require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match") - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination) - - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...) + exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) if testCase.InputPair.Masquerade { require.True(t, exists, "nat rule should be created") @@ -127,8 +103,8 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { require.False(t, foundNat, "nat rule should not exist in the map") } - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) @@ -146,7 +122,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { } } -func TestIptablesManager_RemoveRoutingRules(t *testing.T) { +func TestIptablesManager_RemoveNatRule(t *testing.T) { if !isIptablesSupported() { t.SkipNow() @@ -156,7 +132,7 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") defer func() { _ = manager.Reset() @@ -164,26 +140,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { require.NoError(t, err, "shouldn't return error") - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination) - - err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...) - require.NoError(t, err, "inserting rule should not return error") - - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) - - err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...) 
- require.NoError(t, err, "inserting rule should not return error") - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination) + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...) require.NoError(t, err, "inserting rule should not return error") - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...) require.NoError(t, err, "inserting rule should not return error") @@ -191,28 +155,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { err = manager.Reset() require.NoError(t, err, "shouldn't return error") - err = manager.RemoveRoutingRules(testCase.InputPair) + err = manager.RemoveNatRule(testCase.InputPair) require.NoError(t, err, "shouldn't return error") - exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.False(t, exists, "forwarding rule should not exist") - - _, found := manager.rules[forwardRuleKey] - require.False(t, found, "forwarding rule should exist in the manager map") - - exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.False(t, exists, "income forwarding rule should not exist") - - _, found = manager.rules[inForwardRuleKey] - require.False(t, found, "income forwarding rule should exist in the manager map") - - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...) + exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.False(t, exists, "nat rule should not exist") - _, found = manager.rules[natRuleKey] + _, found := manager.rules[natRuleKey] require.False(t, found, "nat rule should exist in the manager map") exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) 
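For context on the key scheme these tests exercise: the reworked GenKey folds the pair's Inverse flag into the key, so the forward and inverse NAT rules of the same route get distinct entries in the rules map instead of relying on the old InNatFormat constant. Below is a minimal standalone sketch (simplified string-based RouterPair in package main, not the actual client/firewall/manager types) of that behaviour:

    package main

    import "fmt"

    // Simplified stand-in for manager.RouterPair; the real type uses route.ID and netip.Prefix.
    type RouterPair struct {
        ID, Source, Destination string
        Masquerade, Inverse     bool
    }

    // Mirrors the patched constant: the trailing %t carries the Inverse flag.
    const NatFormat = "netbird-nat-%s-%t"

    func GenKey(format string, pair RouterPair) string {
        return fmt.Sprintf(format, pair.ID, pair.Inverse)
    }

    func GetInversePair(pair RouterPair) RouterPair {
        return RouterPair{
            ID:          pair.ID,
            Source:      pair.Destination, // swapped
            Destination: pair.Source,
            Masquerade:  pair.Masquerade,
            Inverse:     true,
        }
    }

    func main() {
        pair := RouterPair{ID: "abc", Source: "100.100.100.1/32", Destination: "100.100.100.0/24", Masquerade: true}
        fmt.Println(GenKey(NatFormat, pair))                 // netbird-nat-abc-false
        fmt.Println(GenKey(NatFormat, GetInversePair(pair))) // netbird-nat-abc-true
    }

This is why the tests here derive inNatRuleKey from firewall.GetInversePair rather than from a separate InNatFormat constant.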
@@ -221,7 +171,175 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { _, found = manager.rules[inNatRuleKey] require.False(t, found, "income nat rule should exist in the manager map") + }) + } +} + +func TestRouter_AddRouteFiltering(t *testing.T) { + if !isIptablesSupported() { + t.Skip("iptables not supported on this system") + } + + iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + require.NoError(t, err, "Failed to create iptables client") + + r, err := newRouter(context.Background(), iptablesClient, ifaceMock) + require.NoError(t, err, "Failed to create router manager") + + defer func() { + err := r.Reset() + require.NoError(t, err, "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + destination netip.Prefix + proto firewall.Protocol + sPort *firewall.Port + dPort *firewall.Port + direction firewall.RuleDirection + action firewall.Action + expectSet bool + }{ + { + name: "Basic TCP rule with single source", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolTCP, + sPort: nil, + dPort: &firewall.Port{Values: []int{80}}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with multiple sources", + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolUDP, + sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionDrop, + expectSet: true, + }, + { + name: "All protocols rule", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + destination: netip.MustParsePrefix("0.0.0.0/0"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "ICMP rule", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolICMP, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "TCP rule with multiple source ports", + sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, + destination: netip.MustParsePrefix("192.168.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{80, 443, 8080}}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with single IP and port range", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolUDP, + sPort: nil, + dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + { + name: "TCP rule with source and destination ports", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + destination: netip.MustParsePrefix("172.16.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, + dPort: &firewall.Port{Values: []int{22}}, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "Drop all incoming traffic", 
+ sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + destination: netip.MustParsePrefix("192.168.0.0/24"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + require.NoError(t, err, "AddRouteFiltering failed") + + // Check if the rule is in the internal map + rule, ok := r.rules[ruleKey.GetRuleID()] + assert.True(t, ok, "Rule not found in internal map") + + // Log the internal rule + t.Logf("Internal rule: %v", rule) + + // Check if the rule exists in iptables + exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...) + assert.NoError(t, err, "Failed to check rule existence") + assert.True(t, exists, "Rule not found in iptables") + + // Verify rule content + params := routeFilteringRuleParams{ + Sources: tt.sources, + Destination: tt.destination, + Proto: tt.proto, + SPort: tt.sPort, + DPort: tt.dPort, + Action: tt.action, + SetName: "", + } + + expectedRule := genRouteFilteringRuleSpec(params) + + if tt.expectSet { + setName := firewall.GenerateSetName(tt.sources) + params.SetName = setName + expectedRule = genRouteFilteringRuleSpec(params) + + // Check if the set was created + _, exists := r.ipsetCounter.Get(setName) + assert.True(t, exists, "IPSet not created") + } + + assert.Equal(t, expectedRule, rule, "Rule content mismatch") + // Clean up + err = r.DeleteRouteRule(ruleKey) + require.NoError(t, err, "Failed to delete rule") }) } } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 6e4edb63e7c..a6185d3708e 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -1,15 +1,21 @@ package manager import ( + "crypto/sha256" + "encoding/hex" "fmt" "net" + "net/netip" + "sort" + "strings" + + log "github.com/sirupsen/logrus" ) const ( - NatFormat = "netbird-nat-%s" - ForwardingFormat = "netbird-fwd-%s" - InNatFormat = "netbird-nat-in-%s" - InForwardingFormat = "netbird-fwd-in-%s" + ForwardingFormatPrefix = "netbird-fwd-" + ForwardingFormat = "netbird-fwd-%s-%t" + NatFormat = "netbird-nat-%s-%t" ) // Rule abstraction should be implemented by each firewall manager @@ -49,11 +55,11 @@ type Manager interface { // AllowNetbird allows netbird interface traffic AllowNetbird() error - // AddFiltering rule to the firewall + // AddPeerFiltering adds a rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule - AddFiltering( + AddPeerFiltering( ip net.IP, proto Protocol, sPort *Port, @@ -64,17 +70,25 @@ type Manager interface { comment string, ) ([]Rule, error) - // DeleteRule from the firewall by rule definition - DeleteRule(rule Rule) error + // DeletePeerRule from the firewall by rule definition + DeletePeerRule(rule Rule) error // IsServerRouteSupported returns true if the firewall supports server side routing operations IsServerRouteSupported() bool - // InsertRoutingRules inserts a routing firewall rule - InsertRoutingRules(pair RouterPair) error + AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error) + + // DeleteRouteRule deletes a routing rule + DeleteRouteRule(rule Rule) error + + // AddNatRule inserts a routing NAT rule + AddNatRule(pair RouterPair) error - // 
RemoveRoutingRules removes a routing firewall rule - RemoveRoutingRules(pair RouterPair) error + // RemoveNatRule removes a routing NAT rule + RemoveNatRule(pair RouterPair) error + + // SetLegacyManagement sets the legacy management mode + SetLegacyManagement(legacy bool) error // Reset firewall to the default state Reset() error @@ -83,6 +97,89 @@ type Manager interface { Flush() error } -func GenKey(format string, input string) string { - return fmt.Sprintf(format, input) +func GenKey(format string, pair RouterPair) string { + return fmt.Sprintf(format, pair.ID, pair.Inverse) +} + +// LegacyManager defines the interface for legacy management operations +type LegacyManager interface { + RemoveAllLegacyRouteRules() error + GetLegacyManagement() bool + SetLegacyManagement(bool) +} + +// SetLegacyManagement sets the route manager to use legacy management +func SetLegacyManagement(router LegacyManager, isLegacy bool) error { + oldLegacy := router.GetLegacyManagement() + + if oldLegacy != isLegacy { + router.SetLegacyManagement(isLegacy) + log.Debugf("Set legacy management to %v", isLegacy) + } + + // client reconnected to a newer mgmt, we need to clean up the legacy rules + if !isLegacy && oldLegacy { + if err := router.RemoveAllLegacyRouteRules(); err != nil { + return fmt.Errorf("remove legacy routing rules: %v", err) + } + + log.Debugf("Legacy routing rules removed") + } + + return nil +} + +// GenerateSetName generates a unique name for an ipset based on the given sources. +func GenerateSetName(sources []netip.Prefix) string { + // sort for consistent naming + sortPrefixes(sources) + + var sourcesStr strings.Builder + for _, src := range sources { + sourcesStr.WriteString(src.String()) + } + + hash := sha256.Sum256([]byte(sourcesStr.String())) + shortHash := hex.EncodeToString(hash[:])[:8] + + return fmt.Sprintf("nb-%s", shortHash) +} + +// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix +func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { + if len(prefixes) == 0 { + return prefixes + } + + merged := []netip.Prefix{prefixes[0]} + for _, prefix := range prefixes[1:] { + last := merged[len(merged)-1] + if last.Contains(prefix.Addr()) { + // If the current prefix is contained within the last merged prefix, skip it + continue + } + if prefix.Contains(last.Addr()) { + // If the current prefix contains the last merged prefix, replace it + merged[len(merged)-1] = prefix + } else { + // Otherwise, add the current prefix to the merged list + merged = append(merged, prefix) + } + } + + return merged +} + +// sortPrefixes sorts the given slice of netip.Prefix in place. +// It sorts first by IP address, then by prefix length (most specific to least specific). 
+func sortPrefixes(prefixes []netip.Prefix) { + sort.Slice(prefixes, func(i, j int) bool { + addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) + if addrCmp != 0 { + return addrCmp < 0 + } + + // If IP addresses are the same, compare prefix lengths (longer prefixes first) + return prefixes[i].Bits() > prefixes[j].Bits() + }) } diff --git a/client/firewall/manager/firewall_test.go b/client/firewall/manager/firewall_test.go new file mode 100644 index 00000000000..3f47d667929 --- /dev/null +++ b/client/firewall/manager/firewall_test.go @@ -0,0 +1,192 @@ +package manager_test + +import ( + "net/netip" + "reflect" + "regexp" + "testing" + + "github.com/netbirdio/netbird/client/firewall/manager" +) + +func TestGenerateSetName(t *testing.T) { + t.Run("Different orders result in same hash", func(t *testing.T) { + prefixes1 := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + } + prefixes2 := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + result1 := manager.GenerateSetName(prefixes1) + result2 := manager.GenerateSetName(prefixes2) + + if result1 != result2 { + t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) + } + }) + + t.Run("Result format is correct", func(t *testing.T) { + prefixes := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + } + + result := manager.GenerateSetName(prefixes) + + matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) + if err != nil { + t.Fatalf("Error matching regex: %v", err) + } + if !matched { + t.Errorf("Result format is incorrect: %s", result) + } + }) + + t.Run("Empty input produces consistent result", func(t *testing.T) { + result1 := manager.GenerateSetName([]netip.Prefix{}) + result2 := manager.GenerateSetName([]netip.Prefix{}) + + if result1 != result2 { + t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) + } + }) + + t.Run("IPv4 and IPv6 mixing", func(t *testing.T) { + prefixes1 := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("2001:db8::/32"), + } + prefixes2 := []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + result1 := manager.GenerateSetName(prefixes1) + result2 := manager.GenerateSetName(prefixes2) + + if result1 != result2 { + t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) + } + }) +} + +func TestMergeIPRanges(t *testing.T) { + tests := []struct { + name string + input []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Empty input", + input: []netip.Prefix{}, + expected: []netip.Prefix{}, + }, + { + name: "Single range", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Two non-overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + }, + { + name: "One range containing another", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "One range containing another (different order)", + input: []netip.Prefix{ 
+ netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "Overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.128/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Overlapping ranges (different order)", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.128/25"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Multiple overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + netip.MustParsePrefix("192.168.1.128/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "Partially overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/23"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/23"), + netip.MustParsePrefix("192.168.2.0/25"), + }, + }, + { + name: "IPv6 ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("2001:db8:1::/48"), + netip.MustParsePrefix("2001:db8:2::/48"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := manager.MergeIPRanges(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected) + } + }) + } +} diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go index b63a9f10432..8c94b7dd4c3 100644 --- a/client/firewall/manager/routerpair.go +++ b/client/firewall/manager/routerpair.go @@ -1,18 +1,26 @@ package manager +import ( + "net/netip" + + "github.com/netbirdio/netbird/route" +) + type RouterPair struct { - ID string - Source string - Destination string + ID route.ID + Source netip.Prefix + Destination netip.Prefix Masquerade bool + Inverse bool } -func GetInPair(pair RouterPair) RouterPair { +func GetInversePair(pair RouterPair) RouterPair { return RouterPair{ ID: pair.ID, // invert Source/Destination Source: pair.Destination, Destination: pair.Source, Masquerade: pair.Masquerade, + Inverse: true, } } diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 1fa41b63a0c..85cba9e1cc2 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -33,9 +33,10 @@ const ( allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) +const flushError = "flush: %w" + var ( - anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00} + anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} ) type AclManager struct { @@ -48,7 +49,6 @@ type AclManager struct { chainInputRules *nftables.Chain chainOutputRules *nftables.Chain chainFwFilter *nftables.Chain - chainPrerouting *nftables.Chain ipsetStore *ipsetStore rules map[string]*Rule @@ -64,7 +64,7 @@ type iFaceMapper interface { func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and 
adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) - // and is permanent. Using same connection for booth type of operations + // and is permanent. Using same connection for both type of operations // overloads netlink with high amount of rules ( > 10000) sConn, err := nftables.New(nftables.AsLasting()) if err != nil { @@ -90,11 +90,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *AclManager) AddFiltering( +func (m *AclManager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -120,20 +120,11 @@ func (m *AclManager) AddFiltering( } newRules = append(newRules, ioRule) - if !shouldAddToPrerouting(proto, dPort, direction) { - return newRules, nil - } - - preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip) - if err != nil { - return newRules, err - } - newRules = append(newRules, preroutingRule) return newRules, nil } -// DeleteRule from the firewall by rule definition -func (m *AclManager) DeleteRule(rule firewall.Rule) error { +// DeletePeerRule from the firewall by rule definition +func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { r, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") @@ -199,8 +190,7 @@ func (m *AclManager) DeleteRule(rule firewall.Rule) error { return nil } -// createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for -// input and output chains +// createDefaultAllowRules creates default allow rules for the input and output chains func (m *AclManager) createDefaultAllowRules() error { expIn := []expr.Any{ &expr.Payload{ @@ -214,13 +204,13 @@ func (m *AclManager) createDefaultAllowRules() error { SourceRegister: 1, DestRegister: 1, Len: 4, - Mask: []byte{0x00, 0x00, 0x00, 0x00}, - Xor: zeroXor, + Mask: []byte{0, 0, 0, 0}, + Xor: []byte{0, 0, 0, 0}, }, // net address &expr.Cmp{ Register: 1, - Data: []byte{0x00, 0x00, 0x00, 0x00}, + Data: []byte{0, 0, 0, 0}, }, &expr.Verdict{ Kind: expr.VerdictAccept, @@ -246,13 +236,13 @@ func (m *AclManager) createDefaultAllowRules() error { SourceRegister: 1, DestRegister: 1, Len: 4, - Mask: []byte{0x00, 0x00, 0x00, 0x00}, - Xor: zeroXor, + Mask: []byte{0, 0, 0, 0}, + Xor: []byte{0, 0, 0, 0}, }, // net address &expr.Cmp{ Register: 1, - Data: []byte{0x00, 0x00, 0x00, 0x00}, + Data: []byte{0, 0, 0, 0}, }, &expr.Verdict{ Kind: expr.VerdictAccept, @@ -266,10 +256,8 @@ func (m *AclManager) createDefaultAllowRules() error { Exprs: expOut, }) - err := m.rConn.Flush() - if err != nil { - log.Debugf("failed to create default allow rules: %s", err) - return err + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) } return nil } @@ -290,15 +278,11 @@ func (m *AclManager) Flush() error { log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) } - if err := m.refreshRuleHandles(m.chainPrerouting); err != nil { - log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err) - } - return nil } func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { - ruleId := generateRuleId(ip, 
sPort, dPort, direction, action, ipset) + ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ r.nftRule, @@ -308,18 +292,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f }, nil } - ifaceKey := expr.MetaKeyIIFNAME - if direction == firewall.RuleDirectionOUT { - ifaceKey = expr.MetaKeyOIFNAME - } - expressions := []expr.Any{ - &expr.Meta{Key: ifaceKey, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - } + var expressions []expr.Any if proto != firewall.ProtocolALL { expressions = append(expressions, &expr.Payload{ @@ -329,21 +302,15 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f Len: uint32(1), }) - var protoData []byte - switch proto { - case firewall.ProtocolTCP: - protoData = []byte{unix.IPPROTO_TCP} - case firewall.ProtocolUDP: - protoData = []byte{unix.IPPROTO_UDP} - case firewall.ProtocolICMP: - protoData = []byte{unix.IPPROTO_ICMP} - default: - return nil, fmt.Errorf("unsupported protocol: %s", proto) + protoData, err := protoToInt(proto) + if err != nil { + return nil, fmt.Errorf("convert protocol to number: %v", err) } + expressions = append(expressions, &expr.Cmp{ Register: 1, Op: expr.CmpOpEq, - Data: protoData, + Data: []byte{protoData}, }) } @@ -432,10 +399,9 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f } else { chain = m.chainOutputRules } - nftRule := m.rConn.InsertRule(&nftables.Rule{ + nftRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: chain, - Position: 0, Exprs: expressions, UserData: userData, }) @@ -453,139 +419,13 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f return rule, nil } -func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) { - var protoData []byte - switch proto { - case firewall.ProtocolTCP: - protoData = []byte{unix.IPPROTO_TCP} - case firewall.ProtocolUDP: - protoData = []byte{unix.IPPROTO_UDP} - case firewall.ProtocolICMP: - protoData = []byte{unix.IPPROTO_ICMP} - default: - return nil, fmt.Errorf("unsupported protocol: %s", proto) - } - - ruleId := generateRuleIdForMangle(ipset, ip, proto, port) - if r, ok := m.rules[ruleId]; ok { - return &Rule{ - r.nftRule, - r.nftSet, - r.ruleID, - ip, - }, nil - } - - var ipExpression expr.Any - // add individual IP for match if no ipset defined - rawIP := ip.To4() - if ipset == nil { - ipExpression = &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: rawIP, - } - } else { - ipExpression = &expr.Lookup{ - SourceRegister: 1, - SetName: ipset.Name, - SetID: ipset.ID, - } - } - - expressions := []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - ipExpression, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: m.wgIface.Address().IP.To4(), - }, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(9), - Len: uint32(1), - }, - &expr.Cmp{ - Register: 1, - Op: expr.CmpOpEq, - Data: protoData, - }, - } - - if port != nil { - expressions = append(expressions, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseTransportHeader, - Offset: 2, - Len: 2, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: encodePort(*port), - }, - ) - } 
- - expressions = append(expressions, - &expr.Immediate{ - Register: 1, - Data: postroutingMark, - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - SourceRegister: true, - Register: 1, - }, - ) - - nftRule := m.rConn.InsertRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainPrerouting, - Position: 0, - Exprs: expressions, - UserData: []byte(ruleId), - }) - - if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf("flush insert rule: %v", err) - } - - rule := &Rule{ - nftRule: nftRule, - nftSet: ipset, - ruleID: ruleId, - ip: ip, - } - - m.rules[ruleId] = rule - if ipset != nil { - m.ipsetStore.AddReferenceToIpset(ipset.Name) - } - return rule, nil -} - func (m *AclManager) createDefaultChains() (err error) { // chainNameInputRules chain := m.createChain(chainNameInputRules) err = m.rConn.Flush() if err != nil { log.Debugf("failed to create chain (%s): %s", chain.Name, err) - return err + return fmt.Errorf(flushError, err) } m.chainInputRules = chain @@ -601,9 +441,6 @@ func (m *AclManager) createDefaultChains() (err error) { // netbird-acl-input-filter // type filter hook input priority filter; policy accept; chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) - //netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept - m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME) - m.addFwdAllow(chain, expr.MetaKeyIIFNAME) m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules m.addDropExpressions(chain, expr.MetaKeyIIFNAME) err = m.rConn.Flush() @@ -615,7 +452,6 @@ func (m *AclManager) createDefaultChains() (err error) { // netbird-acl-output-filter // type filter hook output priority filter; policy accept; chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) - m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME) m.addFwdAllow(chain, expr.MetaKeyOIFNAME) m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules m.addDropExpressions(chain, expr.MetaKeyOIFNAME) @@ -627,24 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) { // netbird-acl-forward-filter m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) - m.addJumpRulesToRtForward() // to - m.addMarkAccept() - m.addJumpRuleToInputChain() // to netbird-acl-input-rules + m.addJumpRulesToRtForward() // to netbird-rt-fwd m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME) + err = m.rConn.Flush() if err != nil { log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err) - return err + return fmt.Errorf(flushError, err) } - // netbird-acl-output-filter - // type filter hook output priority filter; policy accept; - m.chainPrerouting = m.createPreroutingMangle() - err = m.rConn.Flush() - if err != nil { - log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err) - return err - } return nil } @@ -667,59 +494,6 @@ func (m *AclManager) addJumpRulesToRtForward() { Chain: m.chainFwFilter, Exprs: expressions, }) - - expressions = []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictJump, - Chain: m.routeingFwChainName, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) -} - -func (m *AclManager) addMarkAccept() { - // oifname "wt0" meta mark 0x000007e4 accept - // iifname 
"wt0" meta mark 0x000007e4 accept - ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME} - for _, iface := range ifaces { - expressions := []expr.Any{ - &expr.Meta{Key: iface, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: postroutingMark, - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) - } } func (m *AclManager) createChain(name string) *nftables.Chain { @@ -729,6 +503,9 @@ func (m *AclManager) createChain(name string) *nftables.Chain { } chain = m.rConn.AddChain(chain) + + insertReturnTrafficRule(m.rConn, m.workTable, chain) + return chain } @@ -746,74 +523,6 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.Cha return m.rConn.AddChain(chain) } -func (m *AclManager) createPreroutingMangle() *nftables.Chain { - polAccept := nftables.ChainPolicyAccept - chain := &nftables.Chain{ - Name: "netbird-acl-prerouting-filter", - Table: m.workTable, - Hooknum: nftables.ChainHookPrerouting, - Priority: nftables.ChainPriorityMangle, - Type: nftables.ChainTypeFilter, - Policy: &polAccept, - } - - chain = m.rConn.AddChain(chain) - - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - expressions := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: m.wgIface.Address().IP.To4(), - }, - &expr.Immediate{ - Register: 1, - Data: postroutingMark, - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - SourceRegister: true, - Register: 1, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: chain, - Exprs: expressions, - }) - return chain -} - func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any { expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, @@ -832,101 +541,9 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met return nil } -func (m *AclManager) addJumpRuleToInputChain() { - expressions := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictJump, - Chain: m.chainInputRules.Name, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) -} - -func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - var srcOp, dstOp expr.CmpOp - if netIfName == expr.MetaKeyIIFNAME { - srcOp = expr.CmpOpEq - dstOp = expr.CmpOpNeq - } else { - srcOp = expr.CmpOpNeq - dstOp = expr.CmpOpEq - } - expressions := []expr.Any{ - &expr.Meta{Key: netIfName, Register: 1}, - &expr.Cmp{ - Op: 
expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: srcOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: dstOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: expressions, - }) -} - func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - var srcOp, dstOp expr.CmpOp - if iifname == expr.MetaKeyIIFNAME { - srcOp = expr.CmpOpNeq - dstOp = expr.CmpOpEq - } else { - srcOp = expr.CmpOpEq - dstOp = expr.CmpOpNeq - } + dstOp := expr.CmpOpNeq expressions := []expr.Any{ &expr.Meta{Key: iifname, Register: 1}, &expr.Cmp{ @@ -934,24 +551,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: srcOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, &expr.Payload{ DestRegister: 2, Base: expr.PayloadBaseNetworkHeader, @@ -982,7 +581,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { } func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, &expr.Cmp{ @@ -990,47 +588,12 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, &expr.Verdict{ Kind: expr.VerdictJump, Chain: to, }, } + _ = m.rConn.AddRule(&nftables.Rule{ Table: chain.Table, Chain: chain, @@ -1132,7 +695,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error { return nil } -func generateRuleId( +func generatePeerRuleId( ip net.IP, sPort *firewall.Port, dPort *firewall.Port, @@ -1155,33 +718,6 @@ func generateRuleId( } return "set:" + ipset.Name + rulesetID } -func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) 
string { - // case of icmp port is empty - var p string - if port != nil { - p = port.String() - } - if ipset != nil { - return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p) - } else { - return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p) - } -} - -func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool { - if proto == "all" { - return false - } - - if direction != firewall.RuleDirectionIN { - return false - } - - if dPort == nil && proto != firewall.ProtocolICMP { - return false - } - return true -} func encodePort(port firewall.Port) []byte { bs := make([]byte, 2) @@ -1191,6 +727,19 @@ func encodePort(port firewall.Port) []byte { func ifname(n string) []byte { b := make([]byte, 16) - copy(b, []byte(n+"\x00")) + copy(b, n+"\x00") return b } + +func protoToInt(protocol firewall.Protocol) (uint8, error) { + switch protocol { + case firewall.ProtocolTCP: + return unix.IPPROTO_TCP, nil + case firewall.ProtocolUDP: + return unix.IPPROTO_UDP, nil + case firewall.ProtocolICMP: + return unix.IPPROTO_ICMP, nil + } + + return 0, fmt.Errorf("unsupported protocol: %s", protocol) +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index a376c98c316..d2258ae0869 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -5,9 +5,11 @@ import ( "context" "fmt" "net" + "net/netip" "sync" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" @@ -15,8 +17,11 @@ import ( ) const ( - // tableName is the name of the table that is used for filtering by the Netbird client - tableName = "netbird" + // tableNameNetbird is the name of the table that is used for filtering by the Netbird client + tableNameNetbird = "netbird" + + tableNameFilter = "filter" + chainNameInput = "INPUT" ) // Manager of iptables firewall @@ -41,12 +46,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return nil, err } - m.router, err = newRouter(context, workTable) + m.router, err = newRouter(context, workTable, wgIface) if err != nil { return nil, err } - m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName()) + m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) if err != nil { return nil, err } @@ -54,11 +59,11 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -76,33 +81,52 @@ func (m *Manager) AddFiltering( return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) } - return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) + return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) +} + +func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + if !destination.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + } + + return 
m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) } -// DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule firewall.Rule) error { +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.aclManager.DeleteRule(rule) + return m.aclManager.DeletePeerRule(rule) +} + +// DeleteRouteRule deletes a routing rule +func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.DeleteRouteRule(rule) } func (m *Manager) IsServerRouteSupported() bool { return true } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddRoutingRules(pair) + return m.router.AddNatRule(pair) } -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveRoutingRules(pair) + return m.router.RemoveNatRule(pair) } // AllowNetbird allows netbird interface traffic @@ -126,7 +150,7 @@ func (m *Manager) AllowNetbird() error { var chain *nftables.Chain for _, c := range chains { - if c.Table.Name == "filter" && c.Name == "INPUT" { + if c.Table.Name == tableNameFilter && c.Name == chainNameInput { chain = c break } @@ -157,6 +181,27 @@ func (m *Manager) AllowNetbird() error { return nil } +// SetLegacyManagement sets the route manager to use legacy management +func (m *Manager) SetLegacyManagement(isLegacy bool) error { + oldLegacy := m.router.legacyManagement + + if oldLegacy != isLegacy { + m.router.legacyManagement = isLegacy + log.Debugf("Set legacy management to %v", isLegacy) + } + + // client reconnected to a newer mgmt, we need to cleanup the legacy rules + if !isLegacy && oldLegacy { + if err := m.router.RemoveAllLegacyRouteRules(); err != nil { + return fmt.Errorf("remove legacy routing rules: %v", err) + } + + log.Debugf("Legacy routing rules removed") + } + + return nil +} + // Reset firewall to the default state func (m *Manager) Reset() error { m.mutex.Lock() @@ -185,14 +230,16 @@ func (m *Manager) Reset() error { } } - m.router.ResetForwardRules() + if err := m.router.Reset(); err != nil { + return fmt.Errorf("reset forward rules: %v", err) + } tables, err := m.rConn.ListTables() if err != nil { return fmt.Errorf("list of tables: %w", err) } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { m.rConn.DelTable(t) } } @@ -218,12 +265,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { m.rConn.DelTable(t) } } - table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) err = m.rConn.Flush() return table, err } @@ -239,9 +286,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { Register: 1, Data: ifname(m.wgIface.Name()), }, &expr.Verdict{ Kind: expr.VerdictAccept, }, }, UserData: []byte(allowNetbirdInputRuleID), } @@ -251,7 +296,7 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { ifName :=
ifname(m.wgIface.Name()) for _, rule := range existedRules { - if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" { + if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { if len(rule.Exprs) < 4 { if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { continue @@ -265,3 +310,33 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable } return nil } + +func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { + rule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + } + + conn.InsertRule(rule) +} diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 1f226e315a2..7f78a9a2e02 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" "github.com/stretchr/testify/require" "golang.org/x/sys/unix" @@ -17,6 +18,21 @@ import ( "github.com/netbirdio/netbird/iface" ) +var ifaceMock = &iFaceMock{ + NameFunc: func() string { + return "lo" + }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("100.96.0.1"), + Network: &net.IPNet{ + IP: net.ParseIP("100.96.0.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + } + }, +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMock struct { NameFunc func() string @@ -40,23 +56,9 @@ func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { - mock := &iFaceMock{ - NameFunc: func() string { - return "lo" - }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ - IP: net.ParseIP("100.96.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("100.96.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - } - }, - } // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(context.Background(), ifaceMock) require.NoError(t, err) time.Sleep(time.Second * 3) @@ -70,7 +72,7 @@ func TestNftablesManager(t *testing.T) { testClient := &nftables.Conn{} - rule, err := manager.AddFiltering( + rule, err := manager.AddPeerFiltering( ip, fw.ProtocolTCP, nil, @@ -88,17 +90,34 @@ func TestNftablesManager(t *testing.T) { rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) require.NoError(t, err, "failed to get rules") - require.Len(t, rules, 1, "expected 1 rules") + require.Len(t, rules, 2, "expected 2 rules") - ipToAdd, _ := netip.AddrFromSlice(ip) - add := ipToAdd.Unmap() - expectedExprs := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + expectedExprs1 := []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: 
binaryutil.NativeEndian.PutUint32(0), + }, &expr.Cmp{ - Op: expr.CmpOpEq, + Op: expr.CmpOpNeq, Register: 1, - Data: ifname("lo"), + Data: []byte{0, 0, 0, 0}, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, }, + } + require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions") + + ipToAdd, _ := netip.AddrFromSlice(ip) + add := ipToAdd.Unmap() + expectedExprs2 := []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, @@ -134,10 +153,10 @@ func TestNftablesManager(t *testing.T) { }, &expr.Verdict{Kind: expr.VerdictDrop}, } - require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions") + require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions") for _, r := range rule { - err = manager.DeleteRule(r) + err = manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") } @@ -146,7 +165,8 @@ func TestNftablesManager(t *testing.T) { rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) require.NoError(t, err, "failed to get rules") - require.Len(t, rules, 0, "expected 0 rules after deletion") + // established rule remains + require.Len(t, rules, 1, "expected 1 rules after deletion") err = manager.Reset() require.NoError(t, err, "failed to reset") @@ -187,9 +207,9 @@ func TestNFtablesCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/firewall/nftables/route_linux.go b/client/firewall/nftables/route_linux.go deleted file mode 100644 index 71d5ac88e37..00000000000 --- a/client/firewall/nftables/route_linux.go +++ /dev/null @@ -1,431 +0,0 @@ -package nftables - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "net/netip" - - "github.com/google/nftables" - "github.com/google/nftables/binaryutil" - "github.com/google/nftables/expr" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/firewall/manager" -) - -const ( - chainNameRouteingFw = "netbird-rt-fwd" - chainNameRoutingNat = "netbird-rt-nat" - - userDataAcceptForwardRuleSrc = "frwacceptsrc" - userDataAcceptForwardRuleDst = "frwacceptdst" - - loopbackInterface = "lo\x00" -) - -// some presets for building nftable rules -var ( - zeroXor = binaryutil.NativeEndian.PutUint32(0) - - exprCounterAccept = []expr.Any{ - &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - - errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") -) - -type router struct { - ctx context.Context - stop context.CancelFunc - conn *nftables.Conn - workTable *nftables.Table - filterTable *nftables.Table - chains map[string]*nftables.Chain - // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules - rules map[string]*nftables.Rule - isDefaultFwdRulesEnabled bool -} - -func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) { - ctx, cancel := 
context.WithCancel(parentCtx) - - r := &router{ - ctx: ctx, - stop: cancel, - conn: &nftables.Conn{}, - workTable: workTable, - chains: make(map[string]*nftables.Chain), - rules: make(map[string]*nftables.Rule), - } - - var err error - r.filterTable, err = r.loadFilterTable() - if err != nil { - if errors.Is(err, errFilterTableNotFound) { - log.Warnf("table 'filter' not found for forward rules") - } else { - return nil, err - } - } - - err = r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - err = r.createContainers() - if err != nil { - log.Errorf("failed to create containers for route: %s", err) - } - return r, err -} - -func (r *router) RouteingFwChainName() string { - return chainNameRouteingFw -} - -// ResetForwardRules cleans existing nftables default forward rules from the system -func (r *router) ResetForwardRules() { - err := r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to reset forward rules: %s", err) - } -} - -func (r *router) loadFilterTable() (*nftables.Table, error) { - tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) - if err != nil { - return nil, fmt.Errorf("nftables: unable to list tables: %v", err) - } - - for _, table := range tables { - if table.Name == "filter" { - return table, nil - } - } - - return nil, errFilterTableNotFound -} - -func (r *router) createContainers() error { - - r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{ - Name: chainNameRouteingFw, - Table: r.workTable, - }) - - r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ - Name: chainNameRoutingNat, - Table: r.workTable, - Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, - Type: nftables.ChainTypeNAT, - }) - - // Add RETURN rule for loopback interface - loRule := &nftables.Rule{ - Table: r.workTable, - Chain: r.chains[chainNameRoutingNat], - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: []byte(loopbackInterface), - }, - &expr.Verdict{Kind: expr.VerdictReturn}, - }, - } - r.conn.InsertRule(loRule) - - err := r.refreshRulesMap() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: unable to initialize table: %v", err) - } - return nil -} - -// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain -func (r *router) AddRoutingRules(pair manager.RouterPair) error { - err := r.refreshRulesMap() - if err != nil { - return err - } - - err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) - if err != nil { - return err - } - err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) - if err != nil { - return err - } - - if pair.Masquerade { - err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) - if err != nil { - return err - } - err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) - if err != nil { - return err - } - } - - if r.filterTable != nil && !r.isDefaultFwdRulesEnabled { - log.Debugf("add default accept forward rule") - r.acceptForwardRule(pair.Source) - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err) - } - return nil -} - -// addRoutingRule inserts a 
nftable rule to the conn client flush queue -func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { - sourceExp := generateCIDRMatcherExpressions(true, pair.Source) - destExp := generateCIDRMatcherExpressions(false, pair.Destination) - - var expression []expr.Any - if isNat { - expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic - } else { - expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic - } - - ruleKey := manager.GenKey(format, pair.ID) - - _, exists := r.rules[ruleKey] - if exists { - err := r.removeRoutingRule(format, pair) - if err != nil { - return err - } - } - - r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ - Table: r.workTable, - Chain: r.chains[chainName], - Exprs: expression, - UserData: []byte(ruleKey), - }) - return nil -} - -func (r *router) acceptForwardRule(sourceNetwork string) { - src := generateCIDRMatcherExpressions(true, sourceNetwork) - dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0") - - var exprs []expr.Any - exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic - Kind: expr.VerdictAccept, - })...) - - rule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: "FORWARD", - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, - Exprs: exprs, - UserData: []byte(userDataAcceptForwardRuleSrc), - } - - r.conn.AddRule(rule) - - src = generateCIDRMatcherExpressions(true, "0.0.0.0/0") - dst = generateCIDRMatcherExpressions(false, sourceNetwork) - - exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic - Kind: expr.VerdictAccept, - })...) - - rule = &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: "FORWARD", - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, - Exprs: exprs, - UserData: []byte(userDataAcceptForwardRuleDst), - } - r.conn.AddRule(rule) - r.isDefaultFwdRulesEnabled = true -} - -// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains -func (r *router) RemoveRoutingRules(pair manager.RouterPair) error { - err := r.refreshRulesMap() - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.ForwardingFormat, pair) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair)) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.NatFormat, pair) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair)) - if err != nil { - return err - } - - if len(r.rules) == 0 { - err := r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) - } - log.Debugf("nftables: removed rules for %s", pair.Destination) - return nil -} - -// removeRoutingRule add a nftable rule to the removal queue and delete from rules map -func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error { - ruleKey := manager.GenKey(format, pair.ID) - - rule, found := r.rules[ruleKey] - if found { - ruleType := "forwarding" - if rule.Chain.Type == nftables.ChainTypeNAT { - ruleType = "nat" - } - - err := 
r.conn.DelRule(rule) - if err != nil { - return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err) - } - - log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination) - - delete(r.rules, ruleKey) - } - return nil -} - -// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid -// duplicates and to get missing attributes that we don't have when adding new rules -func (r *router) refreshRulesMap() error { - for _, chain := range r.chains { - rules, err := r.conn.GetRules(chain.Table, chain) - if err != nil { - return fmt.Errorf("nftables: unable to list rules: %v", err) - } - for _, rule := range rules { - if len(rule.UserData) > 0 { - r.rules[string(rule.UserData)] = rule - } - } - } - return nil -} - -func (r *router) cleanUpDefaultForwardRules() error { - if r.filterTable == nil { - r.isDefaultFwdRulesEnabled = false - return nil - } - - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return err - } - - var rules []*nftables.Rule - for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name { - continue - } - if chain.Name != "FORWARD" { - continue - } - - rules, err = r.conn.GetRules(r.filterTable, chain) - if err != nil { - return err - } - } - - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) { - err := r.conn.DelRule(rule) - if err != nil { - return err - } - } - } - r.isDefaultFwdRulesEnabled = false - return r.conn.Flush() -} - -// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR -func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any { - ip, network, _ := net.ParseCIDR(cidr) - ipToAdd, _ := netip.AddrFromSlice(ip) - add := ipToAdd.Unmap() - - var offSet uint32 - if source { - offSet = 12 // src offset - } else { - offSet = 16 // dst offset - } - - return []expr.Any{ - // fetch src add - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: offSet, - Len: 4, - }, - // net mask - &expr.Bitwise{ - DestRegister: 1, - SourceRegister: 1, - Len: 4, - Mask: network.Mask, - Xor: zeroXor, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: add.AsSlice(), - }, - } -} diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go new file mode 100644 index 00000000000..aa61e18585f --- /dev/null +++ b/client/firewall/nftables/router_linux.go @@ -0,0 +1,798 @@ +package nftables + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "net" + "net/netip" + "strings" + + "github.com/google/nftables" + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" +) + +const ( + chainNameRoutingFw = "netbird-rt-fwd" + chainNameRoutingNat = "netbird-rt-nat" + chainNameForward = "FORWARD" + + userDataAcceptForwardRuleIif = "frwacceptiif" + userDataAcceptForwardRuleOif = "frwacceptoif" +) + +const refreshRulesMapError = "refresh rules map: %w" + +var ( + errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") +) + +type router struct { + ctx 
context.Context + stop context.CancelFunc + conn *nftables.Conn + workTable *nftables.Table + filterTable *nftables.Table + chains map[string]*nftables.Chain + // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules + rules map[string]*nftables.Rule + ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] + + wgIface iFaceMapper + legacyManagement bool +} + +func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { + ctx, cancel := context.WithCancel(parentCtx) + + r := &router{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + workTable: workTable, + chains: make(map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + wgIface: wgIface, + } + + r.ipsetCounter = refcounter.New( + r.createIpSet, + r.deleteIpSet, + ) + + var err error + r.filterTable, err = r.loadFilterTable() + if err != nil { + if errors.Is(err, errFilterTableNotFound) { + log.Warnf("table 'filter' not found for forward rules") + } else { + return nil, err + } + } + + err = r.cleanUpDefaultForwardRules() + if err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + } + + err = r.createContainers() + if err != nil { + log.Errorf("failed to create containers for route: %s", err) + } + return r, err +} + +// Reset cleans existing nftables default forward rules from the system +func (r *router) Reset() error { + // clear without deleting the ipsets, the nf table will be deleted by the caller + r.ipsetCounter.Clear() + + return r.cleanUpDefaultForwardRules() +} + +func (r *router) cleanUpDefaultForwardRules() error { + if r.filterTable == nil { + return nil + } + + chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return fmt.Errorf("list chains: %v", err) + } + + for _, chain := range chains { + if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + continue + } + + rules, err := r.conn.GetRules(r.filterTable, chain) + if err != nil { + return fmt.Errorf("get rules: %v", err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule: %v", err) + } + } + } + } + + return r.conn.Flush() +} + +func (r *router) loadFilterTable() (*nftables.Table, error) { + tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) + if err != nil { + return nil, fmt.Errorf("nftables: unable to list tables: %v", err) + } + + for _, table := range tables { + if table.Name == "filter" { + return table, nil + } + } + + return nil, errFilterTableNotFound +} + +func (r *router) createContainers() error { + + r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameRoutingFw, + Table: r.workTable, + }) + + insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + + r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameRoutingNat, + Table: r.workTable, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource - 1, + Type: nftables.ChainTypeNAT, + }) + + r.acceptForwardRules() + + err := r.refreshRulesMap() + if err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + } + + err = r.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: unable to initialize table: %v", err) + } + return 
nil +} + +// AddRouteFiltering appends a nftables rule to the routing chain +func (r *router) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) + if _, ok := r.rules[string(ruleKey)]; ok { + return ruleKey, nil + } + + chain := r.chains[chainNameRoutingFw] + var exprs []expr.Any + + switch { + case len(sources) == 1 && sources[0].Bits() == 0: + // If it's 0.0.0.0/0, we don't need to add any source matching + case len(sources) == 1: + // If there's only one source, we can use it directly + exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...) + default: + // If there are multiple sources, create or get an ipset + var err error + exprs, err = r.getIpSetExprs(sources, exprs) + if err != nil { + return nil, fmt.Errorf("get ipset expressions: %w", err) + } + } + + // Handle destination + exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...) + + // Handle protocol + if proto != firewall.ProtocolALL { + protoNum, err := protoToInt(proto) + if err != nil { + return nil, fmt.Errorf("convert protocol to number: %w", err) + } + exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}) + exprs = append(exprs, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{protoNum}, + }) + + exprs = append(exprs, applyPort(sPort, true)...) + exprs = append(exprs, applyPort(dPort, false)...) + } + + exprs = append(exprs, &expr.Counter{}) + + var verdict expr.VerdictKind + if action == firewall.ActionAccept { + verdict = expr.VerdictAccept + } else { + verdict = expr.VerdictDrop + } + exprs = append(exprs, &expr.Verdict{Kind: verdict}) + + rule := &nftables.Rule{ + Table: r.workTable, + Chain: chain, + Exprs: exprs, + UserData: []byte(ruleKey), + } + + r.rules[string(ruleKey)] = r.conn.AddRule(rule) + + return ruleKey, r.conn.Flush() +} + +func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { + setName := firewall.GenerateSetName(sources) + ref, err := r.ipsetCounter.Increment(setName, sources) + if err != nil { + return nil, fmt.Errorf("create or get ipset for sources: %w", err) + } + + exprs = append(exprs, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: ref.Out.Name, + SetID: ref.Out.ID, + }, + ) + return exprs, nil +} + +func (r *router) DeleteRouteRule(rule firewall.Rule) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleKey := rule.GetRuleID() + nftRule, exists := r.rules[ruleKey] + if !exists { + log.Debugf("route rule %s not found", ruleKey) + return nil + } + + setName := r.findSetNameInRule(nftRule) + + if err := r.deleteNftRule(nftRule, ruleKey); err != nil { + return fmt.Errorf("delete: %w", err) + } + + if setName != "" { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + return fmt.Errorf("decrement ipset reference: %w", err) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) { + // overlapping prefixes will result in an error, so we need to merge them + sources = firewall.MergeIPRanges(sources) + + set := &nftables.Set{ + Name: 
setName, + Table: r.workTable, + // required for prefixes + Interval: true, + KeyType: nftables.TypeIPAddr, + } + + var elements []nftables.SetElement + for _, prefix := range sources { + // TODO: Implement IPv6 support + if prefix.Addr().Is6() { + log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) + continue + } + + // nftables needs half-open intervals [firstIP, lastIP) for prefixes + // e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc + firstIP := prefix.Addr() + lastIP := calculateLastIP(prefix).Next() + + elements = append(elements, + // the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247 + // nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true}, + nftables.SetElement{Key: firstIP.AsSlice()}, + nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, + ) + } + + if err := r.conn.AddSet(set, elements); err != nil { + return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err) + } + + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf("flush error: %w", err) + } + + log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2) + + return set, nil +} + +// calculateLastIP determines the last IP in a given prefix. +func calculateLastIP(prefix netip.Prefix) netip.Addr { + hostMask := ^uint32(0) >> prefix.Masked().Bits() + lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask + + return netip.AddrFrom4(uint32ToBytes(lastIP)) +} + +// Utility function to convert netip.Addr to uint32. +func uint32FromNetipAddr(addr netip.Addr) uint32 { + b := addr.As4() + return binary.BigEndian.Uint32(b[:]) +} + +// Utility function to convert uint32 to a netip-compatible byte slice. +func uint32ToBytes(ip uint32) [4]byte { + var b [4]byte + binary.BigEndian.PutUint32(b[:], ip) + return b +} + +func (r *router) deleteIpSet(setName string, set *nftables.Set) error { + r.conn.DelSet(set) + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + log.Debugf("Deleted unused ipset %s", setName) + return nil +} + +func (r *router) findSetNameInRule(rule *nftables.Rule) string { + for _, e := range rule.Exprs { + if lookup, ok := e.(*expr.Lookup); ok { + return lookup.SetName + } + } + return "" +} + +func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule %s: %w", ruleKey, err) + } + delete(r.rules, ruleKey) + + log.Debugf("removed route rule %s", ruleKey) + + return nil +} + +// AddNatRule appends a nftables rule pair to the nat chain +func (r *router) AddNatRule(pair firewall.RouterPair) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + if r.legacyManagement { + log.Warnf("This peer is connected to a NetBird Management service with an older version. 
Allowing all traffic for %s", pair.Destination) + if err := r.addLegacyRouteRule(pair); err != nil { + return fmt.Errorf("add legacy routing rule: %w", err) + } + } + + if pair.Masquerade { + if err := r.addNatRule(pair); err != nil { + return fmt.Errorf("add nat rule: %w", err) + } + + if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("add inverse nat rule: %w", err) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err) + } + + return nil +} + +// addNatRule inserts a nftables rule to the conn client flush queue +func (r *router) addNatRule(pair firewall.RouterPair) error { + sourceExp := generateCIDRMatcherExpressions(true, pair.Source) + destExp := generateCIDRMatcherExpressions(false, pair.Destination) + + dir := expr.MetaKeyIIFNAME + if pair.Inverse { + dir = expr.MetaKeyOIFNAME + } + + intf := ifname(r.wgIface.Name()) + exprs := []expr.Any{ + &expr.Meta{ + Key: dir, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + } + + exprs = append(exprs, sourceExp...) + exprs = append(exprs, destExp...) + exprs = append(exprs, + &expr.Counter{}, &expr.Masq{}, + ) + + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if _, exists := r.rules[ruleKey]; exists { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove routing rule: %w", err) + } + } + + r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingNat], + Exprs: exprs, + UserData: []byte(ruleKey), + }) + return nil +} + +// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls +func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { + sourceExp := generateCIDRMatcherExpressions(true, pair.Source) + destExp := generateCIDRMatcherExpressions(false, pair.Destination) + + exprs := []expr.Any{ + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + + expression := append(sourceExp, append(destExp, exprs...)...) 
// nolint:gocritic + + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if _, exists := r.rules[ruleKey]; exists { + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + } + + r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingFw], + Exprs: expression, + UserData: []byte(ruleKey), + }) + return nil +} + +// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls +func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) + + delete(r.rules, ruleKey) + } else { + log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey) + } + + return nil +} + +// GetLegacyManagement returns the route manager's legacy management mode +func (r *router) GetLegacyManagement() bool { + return r.legacyManagement +} + +// SetLegacyManagement sets the route manager to use legacy management mode +func (r *router) SetLegacyManagement(isLegacy bool) { + r.legacyManagement = isLegacy +} + +// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls +func (r *router) RemoveAllLegacyRouteRules() error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + var merr *multierror.Error + for k, rule := range r.rules { + if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { + continue + } + if err := r.conn.DelRule(rule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) + } + } + return nberrors.FormatErrorOrNil(merr) +} + +// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure +// that our traffic is not dropped by existing rules there. +// The existing FORWARD rules/policies decide outbound traffic towards our interface. +// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. 
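The conntrack match that insertReturnTrafficRule and the oif rule below rely on follows one pattern: load the ct state bitmap into a register, mask it down to the ESTABLISHED and RELATED bits, and accept when any bit remains. A minimal, self-contained sketch of that expression sequence, using the same google/nftables expressions this patch already relies on (the package and helper name here are illustrative only, not part of the patch):

package example

import (
	"github.com/google/nftables/binaryutil"
	"github.com/google/nftables/expr"
)

// ctEstablishedRelatedAccept builds the expression sequence equivalent to
// "ct state established,related accept", as used for return traffic.
func ctEstablishedRelatedAccept() []expr.Any {
	return []expr.Any{
		// load the connection-tracking state bitmap into register 1
		&expr.Ct{Key: expr.CtKeySTATE, Register: 1},
		// keep only the ESTABLISHED and RELATED bits
		&expr.Bitwise{
			SourceRegister: 1,
			DestRegister:   1,
			Len:            4,
			Mask:           binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
			Xor:            binaryutil.NativeEndian.PutUint32(0),
		},
		// accept when at least one of those bits survived the mask (result != 0)
		&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: []byte{0, 0, 0, 0}},
		&expr.Verdict{Kind: expr.VerdictAccept},
	}
}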
+func (r *router) acceptForwardRules() { + if r.filterTable == nil { + log.Debugf("table 'filter' not found for forward rules, skipping accept rules") + return + } + + intf := ifname(r.wgIface.Name()) + + // Rule for incoming interface (iif) with counter + iifRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: "FORWARD", + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptForwardRuleIif), + } + r.conn.InsertRule(iifRule) + + // Rule for outgoing interface (oif) with counter + oifRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: "FORWARD", + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 2, + }, + &expr.Bitwise{ + SourceRegister: 2, + DestRegister: 2, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 2, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptForwardRuleOif), + } + + r.conn.InsertRule(oifRule) +} + +// RemoveNatRule removes a nftables rule pair from nat chains +func (r *router) RemoveNatRule(pair firewall.RouterPair) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove nat rule: %w", err) + } + + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse nat rule: %w", err) + } + + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) + } + + log.Debugf("nftables: removed rules for %s", pair.Destination) + return nil +} + +// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map +func (r *router) removeNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + err := r.conn.DelRule(rule) + if err != nil { + return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination) + + delete(r.rules, ruleKey) + } else { + log.Debugf("nftables: nat rule %s not found", ruleKey) + } + + return nil +} + +// refreshRulesMap refreshes the rule map with the latest rules. 
this is useful to avoid +// duplicates and to get missing attributes that we don't have when adding new rules +func (r *router) refreshRulesMap() error { + for _, chain := range r.chains { + rules, err := r.conn.GetRules(chain.Table, chain) + if err != nil { + return fmt.Errorf("nftables: unable to list rules: %v", err) + } + for _, rule := range rules { + if len(rule.UserData) > 0 { + r.rules[string(rule.UserData)] = rule + } + } + } + return nil +} + +// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR +func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { + var offset uint32 + if source { + offset = 12 // src offset + } else { + offset = 16 // dst offset + } + + ones := prefix.Bits() + // 0.0.0.0/0 doesn't need extra expressions + if ones == 0 { + return nil + } + + mask := net.CIDRMask(ones, 32) + + return []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: offset, + Len: 4, + }, + // netmask + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 4, + Mask: mask, + Xor: []byte{0, 0, 0, 0}, + }, + // net address + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: prefix.Masked().Addr().AsSlice(), + }, + } +} + +func applyPort(port *firewall.Port, isSource bool) []expr.Any { + if port == nil { + return nil + } + + var exprs []expr.Any + + offset := uint32(2) // Default offset for destination port + if isSource { + offset = 0 // Offset for source port + } + + exprs = append(exprs, &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: offset, + Len: 2, + }) + + if port.IsRange && len(port.Values) == 2 { + // Handle port range + exprs = append(exprs, + &expr.Cmp{ + Op: expr.CmpOpGte, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])), + }, + &expr.Cmp{ + Op: expr.CmpOpLte, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])), + }, + ) + } else { + // Handle single port or multiple ports + for i, p := range port.Values { + if i > 0 { + // Add a bitwise OR operation between port checks + exprs = append(exprs, &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: []byte{0x00, 0x00, 0xff, 0xff}, + Xor: []byte{0x00, 0x00, 0x00, 0x00}, + }) + } + exprs = append(exprs, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(p)), + }) + } + } + + return exprs +} diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 913fbd5d2a3..bbf92f3beaf 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -4,11 +4,15 @@ package nftables import ( "context" + "encoding/binary" + "net/netip" + "os/exec" "testing" "github.com/coreos/go-iptables/iptables" "github.com/google/nftables" "github.com/google/nftables/expr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -24,56 +28,50 @@ const ( NFTABLES ) -func TestNftablesManager_InsertRoutingRules(t *testing.T) { +func TestNftablesManager_AddNatRule(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this OS") } table, err := createWorkTable() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Failed to create work table") defer deleteWorkTable() for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := 
newRouter(context.TODO(), table) + manager, err := newRouter(context.TODO(), table, ifaceMock) require.NoError(t, err, "failed to create router") nftablesTestingClient := &nftables.Conn{} - defer manager.ResetForwardRules() + defer func(manager *router) { + require.NoError(t, manager.Reset(), "failed to reset rules") + }(manager) require.NoError(t, err, "shouldn't return error") - err = manager.AddRoutingRules(testCase.InputPair) - defer func() { - _ = manager.RemoveRoutingRules(testCase.InputPair) - }() - require.NoError(t, err, "forwarding pair should be inserted") + err = manager.AddNatRule(testCase.InputPair) + require.NoError(t, err, "pair should be inserted") - sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) - destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - testingExpression := append(sourceExp, destExp...) //nolint:gocritic - fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - - found := 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match") - found = 1 - } - } - } - - require.Equal(t, 1, found, "should find at least 1 rule to test") + defer func(manager *router, pair firewall.RouterPair) { + require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule") + }(manager, testCase.InputPair) if testCase.InputPair.Masquerade { - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) + destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + testingExpression := append(sourceExp, destExp...) //nolint:gocritic + testingExpression = append(testingExpression, + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(ifaceMock.Name()), + }, + ) + + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) found := 0 for _, chain := range manager.chains { rules, err := nftablesTestingClient.GetRules(chain.Table, chain) @@ -88,27 +86,20 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { require.Equal(t, 1, found, "should find at least 1 rule to test") } - sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source) - destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination) - testingExpression = append(sourceExp, destExp...) 
//nolint:gocritic - inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - - found = 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match") - found = 1 - } - } - } - - require.Equal(t, 1, found, "should find at least 1 rule to test") - if testCase.InputPair.Masquerade { - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) + destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + testingExpression := append(sourceExp, destExp...) //nolint:gocritic + testingExpression = append(testingExpression, + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(ifaceMock.Name()), + }, + ) + + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) found := 0 for _, chain := range manager.chains { rules, err := nftablesTestingClient.GetRules(chain.Table, chain) @@ -122,45 +113,37 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { } require.Equal(t, 1, found, "should find at least 1 rule to test") } + }) } } -func TestNftablesManager_RemoveRoutingRules(t *testing.T) { +func TestNftablesManager_RemoveNatRule(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this OS") } table, err := createWorkTable() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Failed to create work table") defer deleteWorkTable() for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table) + manager, err := newRouter(context.TODO(), table, ifaceMock) require.NoError(t, err, "failed to create router") nftablesTestingClient := &nftables.Conn{} - defer manager.ResetForwardRules() + defer func(manager *router) { + require.NoError(t, manager.Reset(), "failed to reset rules") + }(manager) sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRouteingFw], - Exprs: forwardExp, - UserData: []byte(forwardRuleKey), - }) - natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) 
//nolint:gocritic - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ Table: manager.workTable, @@ -169,20 +152,11 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { UserData: []byte(natRuleKey), }) - sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source) - destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination) - - forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRouteingFw], - Exprs: forwardExp, - UserData: []byte(inForwardRuleKey), - }) + sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source) + destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination) natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ Table: manager.workTable, @@ -194,9 +168,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { err = nftablesTestingClient.Flush() require.NoError(t, err, "shouldn't return error") - manager.ResetForwardRules() + err = manager.Reset() + require.NoError(t, err, "shouldn't return error") - err = manager.RemoveRoutingRules(testCase.InputPair) + err = manager.RemoveNatRule(testCase.InputPair) require.NoError(t, err, "shouldn't return error") for _, chain := range manager.chains { @@ -204,9 +179,7 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) for _, rule := range rules { if len(rule.UserData) > 0 { - require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist") require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") - require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist") require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") } } @@ -215,6 +188,468 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { } } +func TestRouter_AddRouteFiltering(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTable() + require.NoError(t, err, "Failed to create work table") + + defer deleteWorkTable() + + r, err := newRouter(context.Background(), workTable, ifaceMock) + require.NoError(t, err, "Failed to create router") + + defer func(r *router) { + require.NoError(t, r.Reset(), "Failed to reset rules") + }(r) + + tests := []struct { + name string + sources []netip.Prefix + destination netip.Prefix + proto firewall.Protocol + sPort *firewall.Port + dPort *firewall.Port + direction firewall.RuleDirection + action firewall.Action + expectSet bool + }{ + { + name: "Basic TCP rule with single source", + sources: 
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolTCP, + sPort: nil, + dPort: &firewall.Port{Values: []int{80}}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with multiple sources", + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolUDP, + sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionDrop, + expectSet: true, + }, + { + name: "All protocols rule", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + destination: netip.MustParsePrefix("0.0.0.0/0"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "ICMP rule", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolICMP, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "TCP rule with multiple source ports", + sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, + destination: netip.MustParsePrefix("192.168.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{80, 443, 8080}}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with single IP and port range", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolUDP, + sPort: nil, + dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + { + name: "TCP rule with source and destination ports", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + destination: netip.MustParsePrefix("172.16.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, + dPort: &firewall.Port{Values: []int{22}}, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "Drop all incoming traffic", + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + destination: netip.MustParsePrefix("192.168.0.0/24"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + require.NoError(t, err, "AddRouteFiltering failed") + + // Check if the rule is in the internal map + rule, ok := r.rules[ruleKey.GetRuleID()] + assert.True(t, ok, "Rule not found in internal map") + + t.Log("Internal rule expressions:") + for i, expr := range rule.Exprs { + t.Logf(" [%d] %T: %+v", i, expr, expr) + } + + // Verify internal rule content + verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) + + // Check if the rule exists in nftables and verify its content + rules, err := 
r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw]) + require.NoError(t, err, "Failed to get rules from nftables") + + var nftRule *nftables.Rule + for _, rule := range rules { + if string(rule.UserData) == ruleKey.GetRuleID() { + nftRule = rule + break + } + } + + require.NotNil(t, nftRule, "Rule not found in nftables") + t.Log("Actual nftables rule expressions:") + for i, expr := range nftRule.Exprs { + t.Logf(" [%d] %T: %+v", i, expr, expr) + } + + // Verify actual nftables rule content + verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) + + // Clean up + err = r.DeleteRouteRule(ruleKey) + require.NoError(t, err, "Failed to delete rule") + }) + } +} + +func TestNftablesCreateIpSet(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTable() + require.NoError(t, err, "Failed to create work table") + + defer deleteWorkTable() + + r, err := newRouter(context.Background(), workTable, ifaceMock) + require.NoError(t, err, "Failed to create router") + + defer func() { + require.NoError(t, r.Reset(), "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Single IP", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + }, + { + name: "Multiple IPs", + sources: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), + netip.MustParsePrefix("10.0.0.1/32"), + netip.MustParsePrefix("172.16.0.1/32"), + }, + }, + { + name: "Single Subnet", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + { + name: "Multiple Subnets with Various Prefix Lengths", + sources: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("203.0.113.0/26"), + }, + }, + { + name: "Mix of Single IPs and Subnets in Different Positions", + sources: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), + netip.MustParsePrefix("10.0.0.0/16"), + netip.MustParsePrefix("172.16.0.1/32"), + netip.MustParsePrefix("203.0.113.0/24"), + }, + }, + { + name: "Overlapping IPs/Subnets", + sources: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("10.0.0.0/16"), + netip.MustParsePrefix("10.0.0.1/32"), + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.1/32"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + } + + // Add this helper function inside TestNftablesCreateIpSet + printNftSets := func() { + cmd := exec.Command("nft", "list", "sets") + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Failed to run 'nft list sets': %v", err) + } else { + t.Logf("Current nft sets:\n%s", output) + } + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setName := firewall.GenerateSetName(tt.sources) + set, err := r.createIpSet(setName, tt.sources) + if err != nil { + t.Logf("Failed to create IP set: %v", err) + printNftSets() + require.NoError(t, err, "Failed to create IP set") + } + require.NotNil(t, set, "Created set is nil") + + // Verify set properties + assert.Equal(t, setName, set.Name, "Set name mismatch") + assert.Equal(t, r.workTable, set.Table, "Set table mismatch") + assert.True(t, set.Interval, "Set interval property should be true") + assert.Equal(t, 
nftables.TypeIPAddr, set.KeyType, "Set key type mismatch") + + // Fetch the created set from nftables + fetchedSet, err := r.conn.GetSetByName(r.workTable, setName) + require.NoError(t, err, "Failed to fetch created set") + require.NotNil(t, fetchedSet, "Fetched set is nil") + + // Verify set elements + elements, err := r.conn.GetSetElements(fetchedSet) + require.NoError(t, err, "Failed to get set elements") + + // Count the number of unique prefixes (excluding interval end markers) + uniquePrefixes := make(map[string]bool) + for _, elem := range elements { + if !elem.IntervalEnd { + ip := netip.AddrFrom4(*(*[4]byte)(elem.Key)) + uniquePrefixes[ip.String()] = true + } + } + + // Check against expected merged prefixes + expectedCount := len(tt.expected) + if expectedCount == 0 { + expectedCount = len(tt.sources) + } + assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected") + + // Verify each expected prefix is in the set + for _, expected := range tt.expected { + found := false + for _, elem := range elements { + if !elem.IntervalEnd { + ip := netip.AddrFrom4(*(*[4]byte)(elem.Key)) + if expected.Contains(ip) { + found = true + break + } + } + } + assert.True(t, found, "Expected prefix %s not found in set", expected) + } + + r.conn.DelSet(set) + if err := r.conn.Flush(); err != nil { + t.Logf("Failed to delete set: %v", err) + printNftSets() + } + require.NoError(t, err, "Failed to delete set") + }) + } +} + +func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { + t.Helper() + + assert.NotNil(t, rule, "Rule should not be nil") + + // Verify sources and destination + if expectSet { + assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources") + } else if len(sources) == 1 && sources[0].Bits() != 0 { + if direction == firewall.RuleDirectionIN { + assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0]) + } else { + assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0]) + } + } + + if direction == firewall.RuleDirectionIN { + assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination) + } else { + assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination) + } + + // Verify protocol + if proto != firewall.ProtocolALL { + assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto) + } + + // Verify ports + if sPort != nil { + assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort) + } + if dPort != nil { + assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort) + } + + // Verify action + assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action) +} + +func containsSetLookup(exprs []expr.Any) bool { + for _, e := range exprs { + if _, ok := e.(*expr.Lookup); ok { + return true + } + } + return false +} + +func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool { + var offset uint32 + if isSource { + offset = 12 
// src offset + } else { + offset = 16 // dst offset + } + + var payloadFound, bitwiseFound, cmpFound bool + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Payload: + if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 { + payloadFound = true + } + case *expr.Bitwise: + if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 { + bitwiseFound = true + } + case *expr.Cmp: + if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 { + cmpFound = true + } + } + } + return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0 +} + +func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { + var offset uint32 = 2 // Default offset for destination port + if isSource { + offset = 0 // Offset for source port + } + + var payloadFound, portMatchFound bool + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Payload: + if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 { + payloadFound = true + } + case *expr.Cmp: + if port.IsRange { + if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte { + portMatchFound = true + } + } else { + if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 { + portValue := binary.BigEndian.Uint16(ex.Data) + for _, p := range port.Values { + if uint16(p) == portValue { + portMatchFound = true + break + } + } + } + } + } + if payloadFound && portMatchFound { + return true + } + } + return false +} + +func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { + var metaFound, cmpFound bool + expectedProto, _ := protoToInt(proto) + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Meta: + if ex.Key == expr.MetaKeyL4PROTO { + metaFound = true + } + case *expr.Cmp: + if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto { + cmpFound = true + } + } + } + return metaFound && cmpFound +} + +func containsAction(exprs []expr.Any, action firewall.Action) bool { + for _, e := range exprs { + if verdict, ok := e.(*expr.Verdict); ok { + switch action { + case firewall.ActionAccept: + return verdict.Kind == expr.VerdictAccept + case firewall.ActionDrop: + return verdict.Kind == expr.VerdictDrop + } + } + } + return false +} + // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. 
func check() int { nf := nftables.Conn{} @@ -250,12 +685,12 @@ func createWorkTable() (*nftables.Table, error) { } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { sConn.DelTable(t) } } - table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) + table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) err = sConn.Flush() return table, err @@ -273,7 +708,7 @@ func deleteWorkTable() { } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { sConn.DelTable(t) } } diff --git a/client/firewall/test/cases_linux.go b/client/firewall/test/cases_linux.go index 432d113dd46..267e93efdbc 100644 --- a/client/firewall/test/cases_linux.go +++ b/client/firewall/test/cases_linux.go @@ -1,8 +1,10 @@ -//go:build !android - package test -import firewall "github.com/netbirdio/netbird/client/firewall/manager" +import ( + "net/netip" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) var ( InsertRuleTestCases = []struct { @@ -13,8 +15,8 @@ var ( Name: "Insert Forwarding IPV4 Rule", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: false, }, }, @@ -22,8 +24,8 @@ var ( Name: "Insert Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: true, }, }, @@ -38,8 +40,8 @@ var ( Name: "Remove Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: true, }, }, diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 75792e9c06b..681058ea949 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -3,6 +3,7 @@ package uspfilter import ( "fmt" "net" + "net/netip" "sync" "github.com/google/gopacket" @@ -103,26 +104,26 @@ func (m *Manager) IsServerRouteSupported() bool { } } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { if m.nativeFirewall == nil { return errRouteNotSupported } - return m.nativeFirewall.InsertRoutingRules(pair) + return m.nativeFirewall.AddNatRule(pair) } -// RemoveRoutingRules removes a routing firewall rule -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +// RemoveNatRule removes a routing firewall rule +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { if m.nativeFirewall == nil { return errRouteNotSupported } - return m.nativeFirewall.RemoveRoutingRules(pair) + return m.nativeFirewall.RemoveNatRule(pair) } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -188,8 +189,22 @@ func (m *Manager) AddFiltering( return []firewall.Rule{&r}, nil } -// DeleteRule from the firewall by rule definition 
-func (m *Manager) DeleteRule(rule firewall.Rule) error { +func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) { + if m.nativeFirewall == nil { + return nil, errRouteNotSupported + } + return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) +} + +func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + if m.nativeFirewall == nil { + return errRouteNotSupported + } + return m.nativeFirewall.DeleteRouteRule(rule) +} + +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -215,6 +230,11 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error { return nil } +// SetLegacyManagement doesn't need to be implemented for this manager +func (m *Manager) SetLegacyManagement(_ bool) error { + return nil +} + // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } @@ -395,7 +415,7 @@ func (m *Manager) RemovePacketHook(hookID string) error { for _, r := range arr { if r.id == hookID { rule := r - return m.DeleteRule(&rule) + return m.DeletePeerRule(&rule) } } } @@ -403,7 +423,7 @@ func (m *Manager) RemovePacketHook(hookID string) error { for _, r := range arr { if r.id == hookID { rule := r - return m.DeleteRule(&rule) + return m.DeletePeerRule(&rule) } } } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 514a9053935..dd7366fe93d 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -49,7 +49,7 @@ func TestManagerCreate(t *testing.T) { } } -func TestManagerAddFiltering(t *testing.T) { +func TestManagerAddPeerFiltering(t *testing.T) { isSetFilterCalled := false ifaceMock := &IFaceMock{ SetFilterFunc: func(iface.PacketFilter) error { @@ -71,7 +71,7 @@ func TestManagerAddFiltering(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -106,7 +106,7 @@ func TestManagerDeleteRule(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -119,14 +119,14 @@ func TestManagerDeleteRule(t *testing.T) { action = fw.ActionDrop comment = "Test rule 2" - rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return } for _, r := range rule { - err = m.DeleteRule(r) + err = m.DeletePeerRule(r) if err != nil { t.Errorf("failed to delete rule: %v", err) return @@ -140,7 +140,7 @@ func TestManagerDeleteRule(t *testing.T) { } for _, r := range rule2 { - err = m.DeleteRule(r) + err = m.DeletePeerRule(r) if err != nil { t.Errorf("failed to delete rule: %v", err) return @@ -252,7 +252,7 @@ func TestManagerReset(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - _, err = m.AddFiltering(ip, 
proto, nil, port, direction, action, "", comment) + _, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -290,7 +290,7 @@ func TestNotMatchByIP(t *testing.T) { action := fw.ActionAccept comment := "Test rule" - _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment) + _, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -406,9 +406,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go new file mode 100644 index 00000000000..e27fce439fc --- /dev/null +++ b/client/internal/acl/id/id.go @@ -0,0 +1,25 @@ +package id + +import ( + "fmt" + "net/netip" + + "github.com/netbirdio/netbird/client/firewall/manager" +) + +type RuleID string + +func (r RuleID) GetRuleID() string { + return string(r) +} + +func GenerateRouteRuleKey( + sources []netip.Prefix, + destination netip.Prefix, + proto manager.Protocol, + sPort *manager.Port, + dPort *manager.Port, + action manager.Action, +) RuleID { + return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action)) +} diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index fd2c2c875d1..ce2a12af16f 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "fmt" "net" + "net/netip" "strconv" "sync" "time" @@ -12,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/ssh" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -23,16 +25,18 @@ type Manager interface { // DefaultManager uses firewall manager to handle type DefaultManager struct { - firewall firewall.Manager - ipsetCounter int - rulesPairs map[string][]firewall.Rule - mutex sync.Mutex + firewall firewall.Manager + ipsetCounter int + peerRulesPairs map[id.RuleID][]firewall.Rule + routeRules map[id.RuleID]struct{} + mutex sync.Mutex } func NewDefaultManager(fm firewall.Manager) *DefaultManager { return &DefaultManager{ - firewall: fm, - rulesPairs: make(map[string][]firewall.Rule), + firewall: fm, + peerRulesPairs: make(map[id.RuleID][]firewall.Rule), + routeRules: make(map[id.RuleID]struct{}), } } @@ -46,7 +50,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { start := time.Now() defer func() { total := 0 - for _, pairs := range d.rulesPairs { + for _, pairs := range d.peerRulesPairs { total += len(pairs) } log.Infof( @@ -59,21 +63,34 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { return } - defer func() { - if err := d.firewall.Flush(); err != nil { - 
log.Error("failed to flush firewall rules: ", err) - } - }() + d.applyPeerACLs(networkMap) + + // If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag, + // then the mgmt server is older than the client, and we need to allow all traffic for routes + isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty + if err := d.firewall.SetLegacyManagement(isLegacy); err != nil { + log.Errorf("failed to set legacy management flag: %v", err) + } + + if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil { + log.Errorf("Failed to apply route ACLs: %v", err) + } + if err := d.firewall.Flush(); err != nil { + log.Error("failed to flush firewall rules: ", err) + } +} + +func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { rules, squashedProtocols := d.squashAcceptRules(networkMap) - enableSSH := (networkMap.PeerConfig != nil && + enableSSH := networkMap.PeerConfig != nil && networkMap.PeerConfig.SshConfig != nil && - networkMap.PeerConfig.SshConfig.SshEnabled) - if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok { + networkMap.PeerConfig.SshConfig.SshEnabled + if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { enableSSH = enableSSH && !ok } - if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok { + if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { enableSSH = enableSSH && !ok } @@ -83,9 +100,9 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { if enableSSH { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, Port: strconv.Itoa(ssh.DefaultSSHPort), }) } @@ -97,20 +114,20 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, ) } - newRulePairs := make(map[string][]firewall.Rule) + newRulePairs := make(map[id.RuleID][]firewall.Rule) ipsetByRuleSelectors := make(map[string]string) for _, r := range rules { @@ -130,29 +147,97 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { break } if len(rules) > 0 { - d.rulesPairs[pairID] = rulePair + d.peerRulesPairs[pairID] = rulePair newRulePairs[pairID] = rulePair } } - for pairID, rules := range d.rulesPairs { + for pairID, rules := range d.peerRulesPairs { if _, ok := newRulePairs[pairID]; !ok { for _, rule := range rules { - if err := d.firewall.DeleteRule(rule); err != nil { - log.Errorf("failed to delete firewall rule: %v", err) + if err := d.firewall.DeletePeerRule(rule); err != nil { + log.Errorf("failed to delete peer firewall rule: %v", err) continue } } - delete(d.rulesPairs, pairID) + delete(d.peerRulesPairs, pairID) + } + } + d.peerRulesPairs = newRulePairs +} + +func (d *DefaultManager) 
applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { + var newRouteRules = make(map[id.RuleID]struct{}) + for _, rule := range rules { + id, err := d.applyRouteACL(rule) + if err != nil { + return fmt.Errorf("apply route ACL: %w", err) + } + newRouteRules[id] = struct{}{} + } + + for id := range d.routeRules { + if _, ok := newRouteRules[id]; !ok { + if err := d.firewall.DeleteRouteRule(id); err != nil { + log.Errorf("failed to delete route firewall rule: %v", err) + continue + } + delete(d.routeRules, id) } } - d.rulesPairs = newRulePairs + d.routeRules = newRouteRules + return nil +} + +func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { + if len(rule.SourceRanges) == 0 { + return "", fmt.Errorf("source ranges is empty") + } + + var sources []netip.Prefix + for _, sourceRange := range rule.SourceRanges { + source, err := netip.ParsePrefix(sourceRange) + if err != nil { + return "", fmt.Errorf("parse source range: %w", err) + } + sources = append(sources, source) + } + + var destination netip.Prefix + if rule.IsDynamic { + destination = getDefault(sources[0]) + } else { + var err error + destination, err = netip.ParsePrefix(rule.Destination) + if err != nil { + return "", fmt.Errorf("parse destination: %w", err) + } + } + + protocol, err := convertToFirewallProtocol(rule.Protocol) + if err != nil { + return "", fmt.Errorf("invalid protocol: %w", err) + } + + action, err := convertFirewallAction(rule.Action) + if err != nil { + return "", fmt.Errorf("invalid action: %w", err) + } + + dPorts := convertPortInfo(rule.PortInfo) + + addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action) + if err != nil { + return "", fmt.Errorf("add route rule: %w", err) + } + + return id.RuleID(addedRule.GetRuleID()), nil } func (d *DefaultManager) protoRuleToFirewallRule( r *mgmProto.FirewallRule, ipsetName string, -) (string, []firewall.Rule, error) { +) (id.RuleID, []firewall.Rule, error) { ip := net.ParseIP(r.PeerIP) if ip == nil { return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") @@ -179,16 +264,16 @@ func (d *DefaultManager) protoRuleToFirewallRule( } } - ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "") - if rulesPair, ok := d.rulesPairs[ruleID]; ok { + ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "") + if rulesPair, ok := d.peerRulesPairs[ruleID]; ok { return ruleID, rulesPair, nil } var rules []firewall.Rule switch r.Direction { - case mgmProto.FirewallRule_IN: + case mgmProto.RuleDirection_IN: rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") - case mgmProto.FirewallRule_OUT: + case mgmProto.RuleDirection_OUT: rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") default: return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") @@ -210,7 +295,7 @@ func (d *DefaultManager) addInRules( comment string, ) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := d.firewall.AddFiltering( + rule, err := d.firewall.AddPeerFiltering( ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -221,7 +306,7 @@ func (d *DefaultManager) addInRules( return rules, nil } - rule, err = d.firewall.AddFiltering( + rule, err = d.firewall.AddPeerFiltering( ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall 
rule: %v", err) @@ -239,7 +324,7 @@ func (d *DefaultManager) addOutRules( comment string, ) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := d.firewall.AddFiltering( + rule, err := d.firewall.AddPeerFiltering( ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -250,7 +335,7 @@ func (d *DefaultManager) addOutRules( return rules, nil } - rule, err = d.firewall.AddFiltering( + rule, err = d.firewall.AddPeerFiltering( ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -259,21 +344,21 @@ func (d *DefaultManager) addOutRules( return append(rules, rule...), nil } -// getRuleID() returns unique ID for the rule based on its parameters. -func (d *DefaultManager) getRuleID( +// getPeerRuleID() returns unique ID for the rule based on its parameters. +func (d *DefaultManager) getPeerRuleID( ip net.IP, proto firewall.Protocol, direction int, port *firewall.Port, action firewall.Action, comment string, -) string { +) id.RuleID { idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment if port != nil { idStr += port.String() } - return hex.EncodeToString(md5.New().Sum([]byte(idStr))) + return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) } // squashAcceptRules does complex logic to convert many rules which allows connection by traffic type @@ -283,7 +368,7 @@ func (d *DefaultManager) getRuleID( // but other has port definitions or has drop policy. func (d *DefaultManager) squashAcceptRules( networkMap *mgmProto.NetworkMap, -) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) { +) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { totalIPs := 0 for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { for range p.AllowedIps { @@ -291,14 +376,14 @@ func (d *DefaultManager) squashAcceptRules( } } - type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int + type protoMatch map[mgmProto.RuleProtocol]map[string]int in := protoMatch{} out := protoMatch{} // trace which type of protocols was squashed squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{} + squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} // this function we use to do calculation, can we squash the rules by protocol or not. // We summ amount of Peers IP for given protocol we found in original rules list. @@ -308,7 +393,7 @@ func (d *DefaultManager) squashAcceptRules( // // We zeroed this to notify squash function that this protocol can't be squashed. 
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) { - drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != "" + drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" if drop { protocols[r.Protocol] = map[string]int{} return @@ -336,7 +421,7 @@ func (d *DefaultManager) squashAcceptRules( for i, r := range networkMap.FirewallRules { // calculate squash for different directions - if r.Direction == mgmProto.FirewallRule_IN { + if r.Direction == mgmProto.RuleDirection_IN { addRuleToCalculationMap(i, r, in) } else { addRuleToCalculationMap(i, r, out) @@ -345,14 +430,14 @@ func (d *DefaultManager) squashAcceptRules( // order of squashing by protocol is important // only for their first element ALL, it must be done first - protocolOrders := []mgmProto.FirewallRuleProtocol{ - mgmProto.FirewallRule_ALL, - mgmProto.FirewallRule_ICMP, - mgmProto.FirewallRule_TCP, - mgmProto.FirewallRule_UDP, + protocolOrders := []mgmProto.RuleProtocol{ + mgmProto.RuleProtocol_ALL, + mgmProto.RuleProtocol_ICMP, + mgmProto.RuleProtocol_TCP, + mgmProto.RuleProtocol_UDP, } - squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) { + squash := func(matches protoMatch, direction mgmProto.RuleDirection) { for _, protocol := range protocolOrders { if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 { // don't squash if : @@ -365,12 +450,12 @@ func (d *DefaultManager) squashAcceptRules( squashedRules = append(squashedRules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", Direction: direction, - Action: mgmProto.FirewallRule_ACCEPT, + Action: mgmProto.RuleAction_ACCEPT, Protocol: protocol, }) squashedProtocols[protocol] = struct{}{} - if protocol == mgmProto.FirewallRule_ALL { + if protocol == mgmProto.RuleProtocol_ALL { // if we have ALL traffic type squashed rule // it allows all other type of traffic, so we can stop processing break @@ -378,11 +463,11 @@ func (d *DefaultManager) squashAcceptRules( } } - squash(in, mgmProto.FirewallRule_IN) - squash(out, mgmProto.FirewallRule_OUT) + squash(in, mgmProto.RuleDirection_IN) + squash(out, mgmProto.RuleDirection_OUT) // if all protocol was squashed everything is allow and we can ignore all other rules - if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok { + if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { return squashedRules, squashedProtocols } @@ -412,26 +497,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port) } -func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) { +func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { log.Debugf("rollback ACL to previous state") for _, rules := range newRulePairs { for _, rule := range rules { - if err := d.firewall.DeleteRule(rule); err != nil { + if err := d.firewall.DeletePeerRule(rule); err != nil { log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err) } } } } -func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) { +func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { switch protocol { - case mgmProto.FirewallRule_TCP: + case mgmProto.RuleProtocol_TCP: return firewall.ProtocolTCP, nil - case mgmProto.FirewallRule_UDP: + case mgmProto.RuleProtocol_UDP: return firewall.ProtocolUDP, nil - case 
mgmProto.FirewallRule_ICMP: + case mgmProto.RuleProtocol_ICMP: return firewall.ProtocolICMP, nil - case mgmProto.FirewallRule_ALL: + case mgmProto.RuleProtocol_ALL: return firewall.ProtocolALL, nil default: return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String()) @@ -442,13 +527,41 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil } -func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) { +func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) { switch action { - case mgmProto.FirewallRule_ACCEPT: + case mgmProto.RuleAction_ACCEPT: return firewall.ActionAccept, nil - case mgmProto.FirewallRule_DROP: + case mgmProto.RuleAction_DROP: return firewall.ActionDrop, nil default: return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action) } } + +func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port { + if portInfo == nil { + return nil + } + + if portInfo.GetPort() != 0 { + return &firewall.Port{ + Values: []int{int(portInfo.GetPort())}, + } + } + + if portInfo.GetRange() != nil { + return &firewall.Port{ + IsRange: true, + Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)}, + } + } + + return nil +} + +func getDefault(prefix netip.Prefix) netip.Prefix { + if prefix.Addr().Is6() { + return netip.PrefixFrom(netip.IPv6Unspecified(), 0) + } + return netip.PrefixFrom(netip.IPv4Unspecified(), 0) +} diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 494d54bf256..eec3d3b8cf1 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -19,16 +19,16 @@ func TestDefaultManager(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, Port: "80", }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_DROP, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_UDP, Port: "53", }, }, @@ -65,16 +65,16 @@ func TestDefaultManager(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) { acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 2 { - t.Errorf("firewall rules not applied: %v", acl.rulesPairs) + if len(acl.peerRulesPairs) != 2 { + t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) return } }) t.Run("add extra rules", func(t *testing.T) { existedPairs := map[string]struct{}{} - for id := range acl.rulesPairs { - existedPairs[id] = struct{}{} + for id := range acl.peerRulesPairs { + existedPairs[id.GetRuleID()] = struct{}{} } // remove first rule @@ -83,24 +83,24 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRules, &mgmProto.FirewallRule{ PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_DROP, - Protocol: mgmProto.FirewallRule_ICMP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_ICMP, }, ) acl.ApplyFiltering(networkMap) // we should have one old and one new rule in the existed rules - if len(acl.rulesPairs) != 2 { + if 
len(acl.peerRulesPairs) != 2 { t.Errorf("firewall rules not applied") return } // check that old rule was removed previousCount := 0 - for id := range acl.rulesPairs { - if _, ok := existedPairs[id]; ok { + for id := range acl.peerRulesPairs { + if _, ok := existedPairs[id.GetRuleID()]; ok { previousCount++ } } @@ -113,15 +113,15 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRulesIsEmpty = true - if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 { - t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs)) + if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { + t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) return } networkMap.FirewallRulesIsEmpty = false acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 2 { - t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs)) + if len(acl.peerRulesPairs) != 2 { + t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) return } }) @@ -138,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, }, } @@ -199,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) { case r.PeerIP != "0.0.0.0": t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) return 
- case r.Direction != mgmProto.FirewallRule_IN: + case r.Direction != mgmProto.RuleDirection_IN: t.Errorf("direction should be IN, got: %v", r.Direction) return - case r.Protocol != mgmProto.FirewallRule_ALL: + case r.Protocol != mgmProto.RuleProtocol_ALL: t.Errorf("protocol should be ALL, got: %v", r.Protocol) return - case r.Action != mgmProto.FirewallRule_ACCEPT: + case r.Action != mgmProto.RuleAction_ACCEPT: t.Errorf("action should be ACCEPT, got: %v", r.Action) return } @@ -215,13 +215,13 @@ func TestDefaultManagerSquashRules(t *testing.T) { case r.PeerIP != "0.0.0.0": t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) return - case r.Direction != mgmProto.FirewallRule_OUT: + case r.Direction != mgmProto.RuleDirection_OUT: t.Errorf("direction should be OUT, got: %v", r.Direction) return - case r.Protocol != mgmProto.FirewallRule_ALL: + case r.Protocol != mgmProto.RuleProtocol_ALL: t.Errorf("protocol should be ALL, got: %v", r.Protocol) return - case r.Action != mgmProto.FirewallRule_ACCEPT: + case r.Action != mgmProto.RuleAction_ACCEPT: t.Errorf("action should be ACCEPT, got: %v", r.Action) return } @@ -238,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, }, }, } @@ -308,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: 
mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, }, }, } @@ -357,8 +357,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 4 { - t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs)) + if len(acl.peerRulesPairs) != 4 { + t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) return } } diff --git a/client/internal/engine.go b/client/internal/engine.go index 463507ad89a..998cbce2de1 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -704,6 +704,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } + // Apply ACLs in the beginning to avoid security leaks + if e.acl != nil { + e.acl.ApplyFiltering(networkMap) + } + protoRoutes := networkMap.GetRoutes() if protoRoutes == nil { protoRoutes = []*mgmProto.Route{} @@ -770,10 +775,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { log.Errorf("failed to update dns server, err: %v", err) } - if e.acl != nil { - e.acl.ApplyFiltering(networkMap) - } - e.networkSerial = serial // Test received (upstream) servers for availability right away instead of upon usage. 
diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 5897031e718..e86a5281086 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -303,7 +303,7 @@ func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]neti var merr *multierror.Error for _, prefix := range prefixes { - if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil { + if _, err := r.routeRefCounter.Increment(prefix, struct{}{}); err != nil { merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err)) continue } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index cdfd322bd5b..d97fe631fe4 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -87,10 +87,10 @@ func NewManager( } dm.routeRefCounter = refcounter.New( - func(prefix netip.Prefix, _ any) (any, error) { - return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface()) + func(prefix netip.Prefix, _ struct{}) (struct{}, error) { + return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface()) }, - func(prefix netip.Prefix, _ any) error { + func(prefix netip.Prefix, _ struct{}) error { return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface()) }, ) diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index f1d696ad95b..65ea0f708ea 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -3,7 +3,8 @@ package refcounter import ( "errors" "fmt" - "net/netip" + "runtime" + "strings" "sync" "github.com/hashicorp/go-multierror" @@ -12,118 +13,153 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" ) -// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix. +const logLevel = log.TraceLevel + +// ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key. var ErrIgnore = errors.New("ignore") +// Ref holds the reference count and associated data for a key. type Ref[O any] struct { Count int Out O } -type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error) -type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error - -type Counter[I, O any] struct { - // refCountMap keeps track of the reference Ref for prefixes - refCountMap map[netip.Prefix]Ref[O] +// AddFunc is the function type for adding a new key. +// Key is the type of the key (e.g., netip.Prefix). +type AddFunc[Key, I, O any] func(key Key, in I) (out O, err error) + +// RemoveFunc is the function type for removing a key. +type RemoveFunc[Key, O any] func(key Key, out O) error + +// Counter is a generic reference counter for managing keys and their associated data. +// Key: The type of the key (e.g., netip.Prefix, string). +// +// I: The input type for the AddFunc. It is the input type for additional data needed +// when adding a key, it is passed as the second argument to AddFunc. +// +// O: The output type for the AddFunc and RemoveFunc. This is the output returned by AddFunc. +// It is stored and passed to RemoveFunc when the reference count reaches 0. 
+// +// The types can be aliased to a specific type using the following syntax: +// +// type RouteRefCounter = Counter[netip.Prefix, any, any] +type Counter[Key comparable, I, O any] struct { + // refCountMap keeps track of the reference Ref for keys + refCountMap map[Key]Ref[O] refCountMu sync.Mutex - // idMap keeps track of the prefixes associated with an ID for removal - idMap map[string][]netip.Prefix + // idMap keeps track of the keys associated with an ID for removal + idMap map[string][]Key idMu sync.Mutex - add AddFunc[I, O] - remove RemoveFunc[I, O] + add AddFunc[Key, I, O] + remove RemoveFunc[Key, O] } -// New creates a new Counter instance -func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] { - return &Counter[I, O]{ - refCountMap: map[netip.Prefix]Ref[O]{}, - idMap: map[string][]netip.Prefix{}, +// New creates a new Counter instance. +// Usage example: +// +// counter := New[netip.Prefix, string, string]( +// func(key netip.Prefix, in string) (out string, err error) { ... }, +// func(key netip.Prefix, out string) error { ... },` +// ) +func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key, O]) *Counter[Key, I, O] { + return &Counter[Key, I, O]{ + refCountMap: map[Key]Ref[O]{}, + idMap: map[string][]Key{}, add: add, remove: remove, } } -// Increment increments the reference count for the given prefix. -// If this is the first reference to the prefix, the AddFunc is called. -func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) { +// Get retrieves the current reference count and associated data for a key. +// If the key doesn't exist, it returns a zero value Ref and false. +func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() - ref := rm.refCountMap[prefix] - log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out) + ref, ok := rm.refCountMap[key] + return ref, ok +} - // Call AddFunc only if it's a new prefix +// Increment increments the reference count for the given key. +// If this is the first reference to the key, the AddFunc is called. +func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + + ref := rm.refCountMap[key] + logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) + + // Call AddFunc only if it's a new key if ref.Count == 0 { - log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out) - out, err := rm.add(prefix, in) + logCallerF("Calling add for key %v", key) + out, err := rm.add(key, in) if errors.Is(err, ErrIgnore) { return ref, nil } if err != nil { - return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err) + return ref, fmt.Errorf("failed to add for key %v: %w", key, err) } ref.Out = out } ref.Count++ - rm.refCountMap[prefix] = ref + rm.refCountMap[key] = ref return ref, nil } -// IncrementWithID increments the reference count for the given prefix and groups it under the given ID. -// If this is the first reference to the prefix, the AddFunc is called. -func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) { +// IncrementWithID increments the reference count for the given key and groups it under the given ID. +// If this is the first reference to the key, the AddFunc is called. 
+func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { rm.idMu.Lock() defer rm.idMu.Unlock() - ref, err := rm.Increment(prefix, in) + ref, err := rm.Increment(key, in) if err != nil { return ref, fmt.Errorf("with ID: %w", err) } - rm.idMap[id] = append(rm.idMap[id], prefix) + rm.idMap[id] = append(rm.idMap[id], key) return ref, nil } -// Decrement decrements the reference count for the given prefix. +// Decrement decrements the reference count for the given key. // If the reference count reaches 0, the RemoveFunc is called. -func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) { +func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() - ref, ok := rm.refCountMap[prefix] + ref, ok := rm.refCountMap[key] if !ok { - log.Tracef("No reference found for prefix %s", prefix) + logCallerF("No reference found for key %v", key) return ref, nil } - log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out) + logCallerF("Decreasing ref count [%d -> %d] for key %v with Out [%v]", ref.Count, ref.Count-1, key, ref.Out) if ref.Count == 1 { - log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out) - if err := rm.remove(prefix, ref.Out); err != nil { - return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err) + logCallerF("Calling remove for key %v", key) + if err := rm.remove(key, ref.Out); err != nil { + return ref, fmt.Errorf("remove for key %v: %w", key, err) } - delete(rm.refCountMap, prefix) + delete(rm.refCountMap, key) } else { ref.Count-- - rm.refCountMap[prefix] = ref + rm.refCountMap[key] = ref } return ref, nil } -// DecrementWithID decrements the reference count for all prefixes associated with the given ID. +// DecrementWithID decrements the reference count for all keys associated with the given ID. // If the reference count reaches 0, the RemoveFunc is called. -func (rm *Counter[I, O]) DecrementWithID(id string) error { +func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { rm.idMu.Lock() defer rm.idMu.Unlock() var merr *multierror.Error - for _, prefix := range rm.idMap[id] { - if _, err := rm.Decrement(prefix); err != nil { + for _, key := range rm.idMap[id] { + if _, err := rm.Decrement(key); err != nil { merr = multierror.Append(merr, err) } } @@ -132,24 +168,77 @@ func (rm *Counter[I, O]) DecrementWithID(id string) error { return nberrors.FormatErrorOrNil(merr) } -// Flush removes all references and calls RemoveFunc for each prefix. -func (rm *Counter[I, O]) Flush() error { +// Flush removes all references and calls RemoveFunc for each key. +func (rm *Counter[Key, I, O]) Flush() error { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() rm.idMu.Lock() defer rm.idMu.Unlock() var merr *multierror.Error - for prefix := range rm.refCountMap { - log.Tracef("Removing for prefix %s", prefix) - ref := rm.refCountMap[prefix] - if err := rm.remove(prefix, ref.Out); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err)) + for key := range rm.refCountMap { + logCallerF("Calling remove for key %v", key) + ref := rm.refCountMap[key] + if err := rm.remove(key, ref.Out); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove for key %v: %w", key, err)) } } - rm.refCountMap = map[netip.Prefix]Ref[O]{} - rm.idMap = map[string][]netip.Prefix{} + clear(rm.refCountMap) + clear(rm.idMap) return nberrors.FormatErrorOrNil(merr) } + +// Clear removes all references without calling RemoveFunc. 
+func (rm *Counter[Key, I, O]) Clear() { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + + clear(rm.refCountMap) + clear(rm.idMap) +} + +func getCallerInfo(depth int, maxDepth int) (string, bool) { + if depth >= maxDepth { + return "", false + } + + pc, _, _, ok := runtime.Caller(depth) + if !ok { + return "", false + } + + if details := runtime.FuncForPC(pc); details != nil { + name := details.Name() + + lastDotIndex := strings.LastIndex(name, "/") + if lastDotIndex != -1 { + name = name[lastDotIndex+1:] + } + + if strings.HasPrefix(name, "refcounter.") { + // +2 to account for recursion + return getCallerInfo(depth+2, maxDepth) + } + + return name, true + } + + return "", false +} + +// logCaller logs a message with the package name and method of the function that called the current function. +func logCallerF(format string, args ...interface{}) { + if log.GetLevel() < logLevel { + return + } + + if callerName, ok := getCallerInfo(3, 18); ok { + format = fmt.Sprintf("[%s] %s", callerName, format) + } + + log.StandardLogger().Logf(logLevel, format, args...) +} diff --git a/client/internal/routemanager/refcounter/types.go b/client/internal/routemanager/refcounter/types.go index 6753b64efe0..aadac3e25ab 100644 --- a/client/internal/routemanager/refcounter/types.go +++ b/client/internal/routemanager/refcounter/types.go @@ -1,7 +1,9 @@ package refcounter +import "net/netip" + // RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement -type RouteRefCounter = Counter[any, any] +type RouteRefCounter = Counter[netip.Prefix, struct{}, struct{}] // AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement -type AllowedIPsRefCounter = Counter[string, string] +type AllowedIPsRefCounter = Counter[netip.Prefix, string, string] diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 43a266cd259..1d1a4b0633e 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -94,7 +94,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error return fmt.Errorf("parse prefix: %w", err) } - err = m.firewall.RemoveRoutingRules(routerPair) + err = m.firewall.RemoveNatRule(routerPair) if err != nil { return fmt.Errorf("remove routing rules: %w", err) } @@ -123,7 +123,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { return fmt.Errorf("parse prefix: %w", err) } - err = m.firewall.InsertRoutingRules(routerPair) + err = m.firewall.AddNatRule(routerPair) if err != nil { return fmt.Errorf("insert routing rules: %w", err) } @@ -157,7 +157,7 @@ func (m *defaultServerRouter) cleanUp() { continue } - err = m.firewall.RemoveRoutingRules(routerPair) + err = m.firewall.RemoveNatRule(routerPair) if err != nil { log.Errorf("Failed to remove cleanup route: %v", err) } @@ -173,15 +173,15 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { // TODO: add ipv6 source := getDefaultPrefix(route.Network) - destination := route.Network.Masked().String() + destination := route.Network.Masked() if route.IsDynamic() { - // TODO: add ipv6 - destination = "0.0.0.0/0" + // TODO: add ipv6 additionally + destination = getDefaultPrefix(destination) } return firewall.RouterPair{ - ID: string(route.ID), - Source: source.String(), + ID: route.ID, + Source: source, Destination: 
destination, Masquerade: route.Masquerade, }, nil diff --git a/client/internal/routemanager/static/route.go b/client/internal/routemanager/static/route.go index 88cca522aed..98c34dbeed9 100644 --- a/client/internal/routemanager/static/route.go +++ b/client/internal/routemanager/static/route.go @@ -30,7 +30,7 @@ func (r *Route) String() string { } func (r *Route) AddRoute(context.Context) error { - _, err := r.routeRefCounter.Increment(r.route.Network, nil) + _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}) return err } diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index ae27b012383..10944c1e22d 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -15,7 +15,7 @@ type Nexthop struct { Intf *net.Interface } -type ExclusionCounter = refcounter.Counter[any, Nexthop] +type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop] type SysOps struct { refCounter *ExclusionCounter diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index d76824c10f0..90f06ba7835 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -41,7 +41,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn } refCounter := refcounter.New( - func(prefix netip.Prefix, _ any) (Nexthop, error) { + func(prefix netip.Prefix, _ struct{}) (Nexthop, error) { initialNexthop := initialNextHopV4 if prefix.Addr().Is6() { initialNexthop = initialNextHopV6 @@ -317,7 +317,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("convert ip to prefix: %w", err) } - if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil { + if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { return fmt.Errorf("adding route reference: %v", err) } diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index 48f048c4c25..672b2a10228 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.21.12 +// protoc v4.23.4 // source: management.proto package proto @@ -21,249 +21,249 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -type HostConfig_Protocol int32 +type RuleProtocol int32 const ( - HostConfig_UDP HostConfig_Protocol = 0 - HostConfig_TCP HostConfig_Protocol = 1 - HostConfig_HTTP HostConfig_Protocol = 2 - HostConfig_HTTPS HostConfig_Protocol = 3 - HostConfig_DTLS HostConfig_Protocol = 4 + RuleProtocol_UNKNOWN RuleProtocol = 0 + RuleProtocol_ALL RuleProtocol = 1 + RuleProtocol_TCP RuleProtocol = 2 + RuleProtocol_UDP RuleProtocol = 3 + RuleProtocol_ICMP RuleProtocol = 4 ) -// Enum value maps for HostConfig_Protocol. +// Enum value maps for RuleProtocol. 
var ( - HostConfig_Protocol_name = map[int32]string{ - 0: "UDP", - 1: "TCP", - 2: "HTTP", - 3: "HTTPS", - 4: "DTLS", + RuleProtocol_name = map[int32]string{ + 0: "UNKNOWN", + 1: "ALL", + 2: "TCP", + 3: "UDP", + 4: "ICMP", } - HostConfig_Protocol_value = map[string]int32{ - "UDP": 0, - "TCP": 1, - "HTTP": 2, - "HTTPS": 3, - "DTLS": 4, + RuleProtocol_value = map[string]int32{ + "UNKNOWN": 0, + "ALL": 1, + "TCP": 2, + "UDP": 3, + "ICMP": 4, } ) -func (x HostConfig_Protocol) Enum() *HostConfig_Protocol { - p := new(HostConfig_Protocol) +func (x RuleProtocol) Enum() *RuleProtocol { + p := new(RuleProtocol) *p = x return p } -func (x HostConfig_Protocol) String() string { +func (x RuleProtocol) String() string { return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) } -func (HostConfig_Protocol) Descriptor() protoreflect.EnumDescriptor { +func (RuleProtocol) Descriptor() protoreflect.EnumDescriptor { return file_management_proto_enumTypes[0].Descriptor() } -func (HostConfig_Protocol) Type() protoreflect.EnumType { +func (RuleProtocol) Type() protoreflect.EnumType { return &file_management_proto_enumTypes[0] } -func (x HostConfig_Protocol) Number() protoreflect.EnumNumber { +func (x RuleProtocol) Number() protoreflect.EnumNumber { return protoreflect.EnumNumber(x) } -// Deprecated: Use HostConfig_Protocol.Descriptor instead. -func (HostConfig_Protocol) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{13, 0} +// Deprecated: Use RuleProtocol.Descriptor instead. +func (RuleProtocol) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{0} } -type DeviceAuthorizationFlowProvider int32 +type RuleDirection int32 const ( - DeviceAuthorizationFlow_HOSTED DeviceAuthorizationFlowProvider = 0 + RuleDirection_IN RuleDirection = 0 + RuleDirection_OUT RuleDirection = 1 ) -// Enum value maps for DeviceAuthorizationFlowProvider. +// Enum value maps for RuleDirection. var ( - DeviceAuthorizationFlowProvider_name = map[int32]string{ - 0: "HOSTED", + RuleDirection_name = map[int32]string{ + 0: "IN", + 1: "OUT", } - DeviceAuthorizationFlowProvider_value = map[string]int32{ - "HOSTED": 0, + RuleDirection_value = map[string]int32{ + "IN": 0, + "OUT": 1, } ) -func (x DeviceAuthorizationFlowProvider) Enum() *DeviceAuthorizationFlowProvider { - p := new(DeviceAuthorizationFlowProvider) +func (x RuleDirection) Enum() *RuleDirection { + p := new(RuleDirection) *p = x return p } -func (x DeviceAuthorizationFlowProvider) String() string { +func (x RuleDirection) String() string { return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) } -func (DeviceAuthorizationFlowProvider) Descriptor() protoreflect.EnumDescriptor { +func (RuleDirection) Descriptor() protoreflect.EnumDescriptor { return file_management_proto_enumTypes[1].Descriptor() } -func (DeviceAuthorizationFlowProvider) Type() protoreflect.EnumType { +func (RuleDirection) Type() protoreflect.EnumType { return &file_management_proto_enumTypes[1] } -func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { +func (x RuleDirection) Number() protoreflect.EnumNumber { return protoreflect.EnumNumber(x) } -// Deprecated: Use DeviceAuthorizationFlowProvider.Descriptor instead. -func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{21, 0} +// Deprecated: Use RuleDirection.Descriptor instead. 
+func (RuleDirection) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{1} } -type FirewallRuleDirection int32 +type RuleAction int32 const ( - FirewallRule_IN FirewallRuleDirection = 0 - FirewallRule_OUT FirewallRuleDirection = 1 + RuleAction_ACCEPT RuleAction = 0 + RuleAction_DROP RuleAction = 1 ) -// Enum value maps for FirewallRuleDirection. +// Enum value maps for RuleAction. var ( - FirewallRuleDirection_name = map[int32]string{ - 0: "IN", - 1: "OUT", + RuleAction_name = map[int32]string{ + 0: "ACCEPT", + 1: "DROP", } - FirewallRuleDirection_value = map[string]int32{ - "IN": 0, - "OUT": 1, + RuleAction_value = map[string]int32{ + "ACCEPT": 0, + "DROP": 1, } ) -func (x FirewallRuleDirection) Enum() *FirewallRuleDirection { - p := new(FirewallRuleDirection) +func (x RuleAction) Enum() *RuleAction { + p := new(RuleAction) *p = x return p } -func (x FirewallRuleDirection) String() string { +func (x RuleAction) String() string { return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) } -func (FirewallRuleDirection) Descriptor() protoreflect.EnumDescriptor { +func (RuleAction) Descriptor() protoreflect.EnumDescriptor { return file_management_proto_enumTypes[2].Descriptor() } -func (FirewallRuleDirection) Type() protoreflect.EnumType { +func (RuleAction) Type() protoreflect.EnumType { return &file_management_proto_enumTypes[2] } -func (x FirewallRuleDirection) Number() protoreflect.EnumNumber { +func (x RuleAction) Number() protoreflect.EnumNumber { return protoreflect.EnumNumber(x) } -// Deprecated: Use FirewallRuleDirection.Descriptor instead. -func (FirewallRuleDirection) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31, 0} +// Deprecated: Use RuleAction.Descriptor instead. +func (RuleAction) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{2} } -type FirewallRuleAction int32 +type HostConfig_Protocol int32 const ( - FirewallRule_ACCEPT FirewallRuleAction = 0 - FirewallRule_DROP FirewallRuleAction = 1 + HostConfig_UDP HostConfig_Protocol = 0 + HostConfig_TCP HostConfig_Protocol = 1 + HostConfig_HTTP HostConfig_Protocol = 2 + HostConfig_HTTPS HostConfig_Protocol = 3 + HostConfig_DTLS HostConfig_Protocol = 4 ) -// Enum value maps for FirewallRuleAction. +// Enum value maps for HostConfig_Protocol. 
var ( - FirewallRuleAction_name = map[int32]string{ - 0: "ACCEPT", - 1: "DROP", + HostConfig_Protocol_name = map[int32]string{ + 0: "UDP", + 1: "TCP", + 2: "HTTP", + 3: "HTTPS", + 4: "DTLS", } - FirewallRuleAction_value = map[string]int32{ - "ACCEPT": 0, - "DROP": 1, + HostConfig_Protocol_value = map[string]int32{ + "UDP": 0, + "TCP": 1, + "HTTP": 2, + "HTTPS": 3, + "DTLS": 4, } ) -func (x FirewallRuleAction) Enum() *FirewallRuleAction { - p := new(FirewallRuleAction) +func (x HostConfig_Protocol) Enum() *HostConfig_Protocol { + p := new(HostConfig_Protocol) *p = x return p } -func (x FirewallRuleAction) String() string { +func (x HostConfig_Protocol) String() string { return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) } -func (FirewallRuleAction) Descriptor() protoreflect.EnumDescriptor { +func (HostConfig_Protocol) Descriptor() protoreflect.EnumDescriptor { return file_management_proto_enumTypes[3].Descriptor() } -func (FirewallRuleAction) Type() protoreflect.EnumType { +func (HostConfig_Protocol) Type() protoreflect.EnumType { return &file_management_proto_enumTypes[3] } -func (x FirewallRuleAction) Number() protoreflect.EnumNumber { +func (x HostConfig_Protocol) Number() protoreflect.EnumNumber { return protoreflect.EnumNumber(x) } -// Deprecated: Use FirewallRuleAction.Descriptor instead. -func (FirewallRuleAction) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31, 1} +// Deprecated: Use HostConfig_Protocol.Descriptor instead. +func (HostConfig_Protocol) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{13, 0} } -type FirewallRuleProtocol int32 +type DeviceAuthorizationFlowProvider int32 const ( - FirewallRule_UNKNOWN FirewallRuleProtocol = 0 - FirewallRule_ALL FirewallRuleProtocol = 1 - FirewallRule_TCP FirewallRuleProtocol = 2 - FirewallRule_UDP FirewallRuleProtocol = 3 - FirewallRule_ICMP FirewallRuleProtocol = 4 + DeviceAuthorizationFlow_HOSTED DeviceAuthorizationFlowProvider = 0 ) -// Enum value maps for FirewallRuleProtocol. +// Enum value maps for DeviceAuthorizationFlowProvider. 
var ( - FirewallRuleProtocol_name = map[int32]string{ - 0: "UNKNOWN", - 1: "ALL", - 2: "TCP", - 3: "UDP", - 4: "ICMP", + DeviceAuthorizationFlowProvider_name = map[int32]string{ + 0: "HOSTED", } - FirewallRuleProtocol_value = map[string]int32{ - "UNKNOWN": 0, - "ALL": 1, - "TCP": 2, - "UDP": 3, - "ICMP": 4, + DeviceAuthorizationFlowProvider_value = map[string]int32{ + "HOSTED": 0, } ) -func (x FirewallRuleProtocol) Enum() *FirewallRuleProtocol { - p := new(FirewallRuleProtocol) +func (x DeviceAuthorizationFlowProvider) Enum() *DeviceAuthorizationFlowProvider { + p := new(DeviceAuthorizationFlowProvider) *p = x return p } -func (x FirewallRuleProtocol) String() string { +func (x DeviceAuthorizationFlowProvider) String() string { return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) } -func (FirewallRuleProtocol) Descriptor() protoreflect.EnumDescriptor { +func (DeviceAuthorizationFlowProvider) Descriptor() protoreflect.EnumDescriptor { return file_management_proto_enumTypes[4].Descriptor() } -func (FirewallRuleProtocol) Type() protoreflect.EnumType { +func (DeviceAuthorizationFlowProvider) Type() protoreflect.EnumType { return &file_management_proto_enumTypes[4] } -func (x FirewallRuleProtocol) Number() protoreflect.EnumNumber { +func (x DeviceAuthorizationFlowProvider) Number() protoreflect.EnumNumber { return protoreflect.EnumNumber(x) } -// Deprecated: Use FirewallRuleProtocol.Descriptor instead. -func (FirewallRuleProtocol) EnumDescriptor() ([]byte, []int) { - return file_management_proto_rawDescGZIP(), []int{31, 2} +// Deprecated: Use DeviceAuthorizationFlowProvider.Descriptor instead. +func (DeviceAuthorizationFlowProvider) EnumDescriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{21, 0} } type EncryptedMessage struct { @@ -1482,6 +1482,10 @@ type NetworkMap struct { FirewallRules []*FirewallRule `protobuf:"bytes,8,rep,name=FirewallRules,proto3" json:"FirewallRules,omitempty"` // firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality. FirewallRulesIsEmpty bool `protobuf:"varint,9,opt,name=firewallRulesIsEmpty,proto3" json:"firewallRulesIsEmpty,omitempty"` + // RoutesFirewallRules represents a list of routes firewall rules to be applied to peer + RoutesFirewallRules []*RouteFirewallRule `protobuf:"bytes,10,rep,name=routesFirewallRules,proto3" json:"routesFirewallRules,omitempty"` + // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality. + RoutesFirewallRulesIsEmpty bool `protobuf:"varint,11,opt,name=routesFirewallRulesIsEmpty,proto3" json:"routesFirewallRulesIsEmpty,omitempty"` } func (x *NetworkMap) Reset() { @@ -1579,6 +1583,20 @@ func (x *NetworkMap) GetFirewallRulesIsEmpty() bool { return false } +func (x *NetworkMap) GetRoutesFirewallRules() []*RouteFirewallRule { + if x != nil { + return x.RoutesFirewallRules + } + return nil +} + +func (x *NetworkMap) GetRoutesFirewallRulesIsEmpty() bool { + if x != nil { + return x.RoutesFirewallRulesIsEmpty + } + return false +} + // RemotePeerConfig represents a configuration of a remote peer. 
// The properties are used to configure WireGuard Peers sections type RemotePeerConfig struct { @@ -2487,11 +2505,11 @@ type FirewallRule struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - PeerIP string `protobuf:"bytes,1,opt,name=PeerIP,proto3" json:"PeerIP,omitempty"` - Direction FirewallRuleDirection `protobuf:"varint,2,opt,name=Direction,proto3,enum=management.FirewallRuleDirection" json:"Direction,omitempty"` - Action FirewallRuleAction `protobuf:"varint,3,opt,name=Action,proto3,enum=management.FirewallRuleAction" json:"Action,omitempty"` - Protocol FirewallRuleProtocol `protobuf:"varint,4,opt,name=Protocol,proto3,enum=management.FirewallRuleProtocol" json:"Protocol,omitempty"` - Port string `protobuf:"bytes,5,opt,name=Port,proto3" json:"Port,omitempty"` + PeerIP string `protobuf:"bytes,1,opt,name=PeerIP,proto3" json:"PeerIP,omitempty"` + Direction RuleDirection `protobuf:"varint,2,opt,name=Direction,proto3,enum=management.RuleDirection" json:"Direction,omitempty"` + Action RuleAction `protobuf:"varint,3,opt,name=Action,proto3,enum=management.RuleAction" json:"Action,omitempty"` + Protocol RuleProtocol `protobuf:"varint,4,opt,name=Protocol,proto3,enum=management.RuleProtocol" json:"Protocol,omitempty"` + Port string `protobuf:"bytes,5,opt,name=Port,proto3" json:"Port,omitempty"` } func (x *FirewallRule) Reset() { @@ -2533,25 +2551,25 @@ func (x *FirewallRule) GetPeerIP() string { return "" } -func (x *FirewallRule) GetDirection() FirewallRuleDirection { +func (x *FirewallRule) GetDirection() RuleDirection { if x != nil { return x.Direction } - return FirewallRule_IN + return RuleDirection_IN } -func (x *FirewallRule) GetAction() FirewallRuleAction { +func (x *FirewallRule) GetAction() RuleAction { if x != nil { return x.Action } - return FirewallRule_ACCEPT + return RuleAction_ACCEPT } -func (x *FirewallRule) GetProtocol() FirewallRuleProtocol { +func (x *FirewallRule) GetProtocol() RuleProtocol { if x != nil { return x.Protocol } - return FirewallRule_UNKNOWN + return RuleProtocol_UNKNOWN } func (x *FirewallRule) GetPort() string { @@ -2663,6 +2681,236 @@ func (x *Checks) GetFiles() []string { return nil } +type PortInfo struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to PortSelection: + // + // *PortInfo_Port + // *PortInfo_Range_ + PortSelection isPortInfo_PortSelection `protobuf_oneof:"portSelection"` +} + +func (x *PortInfo) Reset() { + *x = PortInfo{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[34] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PortInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PortInfo) ProtoMessage() {} + +func (x *PortInfo) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[34] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PortInfo.ProtoReflect.Descriptor instead. 
+func (*PortInfo) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{34} +} + +func (m *PortInfo) GetPortSelection() isPortInfo_PortSelection { + if m != nil { + return m.PortSelection + } + return nil +} + +func (x *PortInfo) GetPort() uint32 { + if x, ok := x.GetPortSelection().(*PortInfo_Port); ok { + return x.Port + } + return 0 +} + +func (x *PortInfo) GetRange() *PortInfo_Range { + if x, ok := x.GetPortSelection().(*PortInfo_Range_); ok { + return x.Range + } + return nil +} + +type isPortInfo_PortSelection interface { + isPortInfo_PortSelection() +} + +type PortInfo_Port struct { + Port uint32 `protobuf:"varint,1,opt,name=port,proto3,oneof"` +} + +type PortInfo_Range_ struct { + Range *PortInfo_Range `protobuf:"bytes,2,opt,name=range,proto3,oneof"` +} + +func (*PortInfo_Port) isPortInfo_PortSelection() {} + +func (*PortInfo_Range_) isPortInfo_PortSelection() {} + +// RouteFirewallRule signifies a firewall rule applicable for a routed network. +type RouteFirewallRule struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // sourceRanges IP ranges of the routing peers. + SourceRanges []string `protobuf:"bytes,1,rep,name=sourceRanges,proto3" json:"sourceRanges,omitempty"` + // Action to be taken by the firewall when the rule is applicable. + Action RuleAction `protobuf:"varint,2,opt,name=action,proto3,enum=management.RuleAction" json:"action,omitempty"` + // Network prefix for the routed network. + Destination string `protobuf:"bytes,3,opt,name=destination,proto3" json:"destination,omitempty"` + // Protocol of the routed network. + Protocol RuleProtocol `protobuf:"varint,4,opt,name=protocol,proto3,enum=management.RuleProtocol" json:"protocol,omitempty"` + // Details about the port. + PortInfo *PortInfo `protobuf:"bytes,5,opt,name=portInfo,proto3" json:"portInfo,omitempty"` + // IsDynamic indicates if the route is a DNS route. + IsDynamic bool `protobuf:"varint,6,opt,name=isDynamic,proto3" json:"isDynamic,omitempty"` +} + +func (x *RouteFirewallRule) Reset() { + *x = RouteFirewallRule{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[35] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RouteFirewallRule) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RouteFirewallRule) ProtoMessage() {} + +func (x *RouteFirewallRule) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[35] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RouteFirewallRule.ProtoReflect.Descriptor instead. 
+func (*RouteFirewallRule) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{35} +} + +func (x *RouteFirewallRule) GetSourceRanges() []string { + if x != nil { + return x.SourceRanges + } + return nil +} + +func (x *RouteFirewallRule) GetAction() RuleAction { + if x != nil { + return x.Action + } + return RuleAction_ACCEPT +} + +func (x *RouteFirewallRule) GetDestination() string { + if x != nil { + return x.Destination + } + return "" +} + +func (x *RouteFirewallRule) GetProtocol() RuleProtocol { + if x != nil { + return x.Protocol + } + return RuleProtocol_UNKNOWN +} + +func (x *RouteFirewallRule) GetPortInfo() *PortInfo { + if x != nil { + return x.PortInfo + } + return nil +} + +func (x *RouteFirewallRule) GetIsDynamic() bool { + if x != nil { + return x.IsDynamic + } + return false +} + +type PortInfo_Range struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` + End uint32 `protobuf:"varint,2,opt,name=end,proto3" json:"end,omitempty"` +} + +func (x *PortInfo_Range) Reset() { + *x = PortInfo_Range{} + if protoimpl.UnsafeEnabled { + mi := &file_management_proto_msgTypes[36] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PortInfo_Range) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PortInfo_Range) ProtoMessage() {} + +func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { + mi := &file_management_proto_msgTypes[36] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PortInfo_Range.ProtoReflect.Descriptor instead. 
+func (*PortInfo_Range) Descriptor() ([]byte, []int) { + return file_management_proto_rawDescGZIP(), []int{34, 0} +} + +func (x *PortInfo_Range) GetStart() uint32 { + if x != nil { + return x.Start + } + return 0 +} + +func (x *PortInfo_Range) GetEnd() uint32 { + if x != nil { + return x.End + } + return 0 +} + var File_management_proto protoreflect.FileDescriptor var file_management_proto_rawDesc = []byte{ @@ -2835,7 +3083,7 @@ var file_management_proto_rawDesc = []byte{ 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xe2, 0x03, 0x0a, 0x0a, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xf3, 0x04, 0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, @@ -2866,184 +3114,219 @@ var file_management_proto_rawDesc = []byte{ 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, - 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, - 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, - 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 
0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, - 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, - 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, - 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, - 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, - 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, - 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, - 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, - 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, - 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, - 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, - 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, - 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, - 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, - 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, - 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, - 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, - 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, - 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, - 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, - 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, - 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, - 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, - 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, - 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, - 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, - 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 
0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, - 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, - 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, - 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, - 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, - 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, - 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, - 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, - 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, - 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, - 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, - 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, - 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xf0, 0x02, 0x0a, 0x0c, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, - 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, - 0x72, 0x49, 0x50, 0x12, 0x40, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x22, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x37, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1f, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, - 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x3d, - 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x21, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, - 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 
0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, - 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, - 0x74, 0x22, 0x1c, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, - 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x22, - 0x1e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, - 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x22, - 0x3c, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, - 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, - 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, - 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x22, 0x38, 0x0a, - 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, - 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, - 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, - 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, - 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, - 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, - 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 
0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, + 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, + 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, + 0x79, 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, + 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, + 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, + 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, + 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, + 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, + 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, + 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, + 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, + 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, + 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, + 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, + 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, + 0x0a, 0x06, 
0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, + 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, + 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, + 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, + 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, + 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, + 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, + 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, + 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, + 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, + 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, + 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, + 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, + 0x55, 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, + 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, + 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 
0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, + 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, + 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, + 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, + 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, + 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, + 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, + 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, + 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, + 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, + 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, + 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, + 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, + 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, + 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 
0x76, 0x65, 0x72, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, + 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, + 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, + 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, + 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, + 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, + 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, + 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, + 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xd9, 0x01, 0x0a, 0x0c, + 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, + 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, + 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, + 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, + 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, + 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, + 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, + 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, + 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, + 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, + 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, + 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, + 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 
0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, + 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, + 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, + 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, + 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x8f, 0x02, 0x0a, 0x11, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, + 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, + 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, + 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, + 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, + 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, + 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, + 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x2a, 0x40, 0x0a, 0x0c, + 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, + 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, + 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, + 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x2a, 0x20, + 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, + 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, + 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, + 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 
0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, - 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, - 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, + 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, + 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, + 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, + 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, + 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, + 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 
0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -3059,13 +3342,13 @@ func file_management_proto_rawDescGZIP() []byte { } var file_management_proto_enumTypes = make([]protoimpl.EnumInfo, 5) -var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 34) +var file_management_proto_msgTypes = make([]protoimpl.MessageInfo, 37) var file_management_proto_goTypes = []interface{}{ - (HostConfig_Protocol)(0), // 0: management.HostConfig.Protocol - (DeviceAuthorizationFlowProvider)(0), // 1: management.DeviceAuthorizationFlow.provider - (FirewallRuleDirection)(0), // 2: management.FirewallRule.direction - (FirewallRuleAction)(0), // 3: management.FirewallRule.action - (FirewallRuleProtocol)(0), // 4: management.FirewallRule.protocol + (RuleProtocol)(0), // 0: management.RuleProtocol + (RuleDirection)(0), // 1: management.RuleDirection + (RuleAction)(0), // 2: management.RuleAction + (HostConfig_Protocol)(0), // 3: management.HostConfig.Protocol + (DeviceAuthorizationFlowProvider)(0), // 4: management.DeviceAuthorizationFlow.provider (*EncryptedMessage)(nil), // 5: management.EncryptedMessage (*SyncRequest)(nil), // 6: management.SyncRequest (*SyncResponse)(nil), // 7: management.SyncResponse @@ -3100,7 +3383,10 @@ var file_management_proto_goTypes = []interface{}{ (*FirewallRule)(nil), // 36: management.FirewallRule (*NetworkAddress)(nil), // 37: management.NetworkAddress (*Checks)(nil), // 38: management.Checks - (*timestamppb.Timestamp)(nil), // 39: google.protobuf.Timestamp + (*PortInfo)(nil), // 39: management.PortInfo + (*RouteFirewallRule)(nil), // 40: management.RouteFirewallRule + (*PortInfo_Range)(nil), // 41: management.PortInfo.Range + (*timestamppb.Timestamp)(nil), // 42: google.protobuf.Timestamp } var file_management_proto_depIdxs = []int32{ 13, // 0: management.SyncRequest.meta:type_name -> management.PeerSystemMeta @@ -3118,12 +3404,12 @@ var file_management_proto_depIdxs = []int32{ 17, // 12: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig 21, // 13: management.LoginResponse.peerConfig:type_name -> management.PeerConfig 38, // 14: management.LoginResponse.Checks:type_name -> management.Checks - 39, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 42, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 18, // 16: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig 20, // 17: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig 18, // 18: management.WiretrusteeConfig.signal:type_name -> management.HostConfig 19, // 19: management.WiretrusteeConfig.relay:type_name -> management.RelayConfig - 0, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 3, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol 18, // 21: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig 24, // 22: management.PeerConfig.sshConfig:type_name -> management.SSHConfig 21, // 23: management.NetworkMap.peerConfig:type_name -> management.PeerConfig @@ -3132,36 +3418,41 @@ var file_management_proto_depIdxs = []int32{ 31, // 26: 
management.NetworkMap.DNSConfig:type_name -> management.DNSConfig 23, // 27: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig 36, // 28: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 24, // 29: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 1, // 30: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 29, // 31: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 29, // 32: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 34, // 33: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 32, // 34: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 33, // 35: management.CustomZone.Records:type_name -> management.SimpleRecord - 35, // 36: management.NameServerGroup.NameServers:type_name -> management.NameServer - 2, // 37: management.FirewallRule.Direction:type_name -> management.FirewallRule.direction - 3, // 38: management.FirewallRule.Action:type_name -> management.FirewallRule.action - 4, // 39: management.FirewallRule.Protocol:type_name -> management.FirewallRule.protocol - 5, // 40: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 41: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 16, // 42: management.ManagementService.GetServerKey:input_type -> management.Empty - 16, // 43: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 44: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 45: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 46: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 47: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 48: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 15, // 49: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 16, // 50: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 51: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 52: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 16, // 53: management.ManagementService.SyncMeta:output_type -> management.Empty - 47, // [47:54] is the sub-list for method output_type - 40, // [40:47] is the sub-list for method input_type - 40, // [40:40] is the sub-list for extension type_name - 40, // [40:40] is the sub-list for extension extendee - 0, // [0:40] is the sub-list for field type_name + 40, // 29: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 24, // 30: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 4, // 31: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 29, // 32: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 29, // 33: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 34, // 34: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 32, // 35: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 33, // 36: 
management.CustomZone.Records:type_name -> management.SimpleRecord + 35, // 37: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 38: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 39: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 40: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 41, // 41: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 42: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 43: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 39, // 44: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 5, // 45: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 46: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 16, // 47: management.ManagementService.GetServerKey:input_type -> management.Empty + 16, // 48: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 49: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 50: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 51: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 52: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 53: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 15, // 54: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 16, // 55: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 56: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 57: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 16, // 58: management.ManagementService.SyncMeta:output_type -> management.Empty + 52, // [52:59] is the sub-list for method output_type + 45, // [45:52] is the sub-list for method input_type + 45, // [45:45] is the sub-list for extension type_name + 45, // [45:45] is the sub-list for extension extendee + 0, // [0:45] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -3578,6 +3869,46 @@ func file_management_proto_init() { return nil } } + file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PortInfo); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RouteFirewallRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PortInfo_Range); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_management_proto_msgTypes[34].OneofWrappers = []interface{}{ + (*PortInfo_Port)(nil), + (*PortInfo_Range_)(nil), } type x struct{} out := protoimpl.TypeBuilder{ @@ -3585,7 +3916,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 5, - 
NumMessages: 34, + NumMessages: 37, NumExtensions: 0, NumServices: 1, }, diff --git a/management/proto/management.proto b/management/proto/management.proto index c5646820f96..fe6a828b1e5 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -254,6 +254,12 @@ message NetworkMap { // firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality. bool firewallRulesIsEmpty = 9; + + // RoutesFirewallRules represents a list of routes firewall rules to be applied to peer + repeated RouteFirewallRule routesFirewallRules = 10; + + // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality. + bool routesFirewallRulesIsEmpty = 11; } // RemotePeerConfig represents a configuration of a remote peer. @@ -384,29 +390,32 @@ message NameServer { int64 Port = 3; } +enum RuleProtocol { + UNKNOWN = 0; + ALL = 1; + TCP = 2; + UDP = 3; + ICMP = 4; +} + +enum RuleDirection { + IN = 0; + OUT = 1; +} + +enum RuleAction { + ACCEPT = 0; + DROP = 1; +} + + // FirewallRule represents a firewall rule message FirewallRule { string PeerIP = 1; - direction Direction = 2; - action Action = 3; - protocol Protocol = 4; + RuleDirection Direction = 2; + RuleAction Action = 3; + RuleProtocol Protocol = 4; string Port = 5; - - enum direction { - IN = 0; - OUT = 1; - } - enum action { - ACCEPT = 0; - DROP = 1; - } - enum protocol { - UNKNOWN = 0; - ALL = 1; - TCP = 2; - UDP = 3; - ICMP = 4; - } } message NetworkAddress { @@ -415,5 +424,40 @@ message NetworkAddress { } message Checks { - repeated string Files= 1; + repeated string Files = 1; } + + +message PortInfo { + oneof portSelection { + uint32 port = 1; + Range range = 2; + } + + message Range { + uint32 start = 1; + uint32 end = 2; + } +} + +// RouteFirewallRule signifies a firewall rule applicable for a routed network. +message RouteFirewallRule { + // sourceRanges IP ranges of the routing peers. + repeated string sourceRanges = 1; + + // Action to be taken by the firewall when the rule is applicable. + RuleAction action = 2; + + // Network prefix for the routed network. + string destination = 3; + + // Protocol of the routed network. + RuleProtocol protocol = 4; + + // Details about the port. + PortInfo portInfo = 5; + + // IsDynamic indicates if the route is a DNS route. 
+ bool isDynamic = 6; +} + diff --git a/management/server/account.go b/management/server/account.go index 710b6f62f35..d5e8c8cf8b1 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -113,7 +113,7 @@ type AccountManager interface { DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) - CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) @@ -460,6 +460,7 @@ func (a *Account) GetPeerNetworkMap( } routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect) + routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) dnsUpdate := nbdns.Config{ @@ -483,6 +484,7 @@ func (a *Account) GetPeerNetworkMap( DNSConfig: dnsUpdate, OfflinePeers: expiredPeers, FirewallRules: firewallRules, + RoutesFirewallRules: routesFirewallRules, } if metrics != nil { diff --git a/management/server/account_test.go b/management/server/account_test.go index 303261bead6..e554ae493ea 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1599,9 +1599,10 @@ func TestAccount_Copy(t *testing.T) { }, Routes: map[route.ID]*route.Route{ "route1": { - ID: "route1", - PeerGroups: []string{}, - Groups: []string{"group1"}, + ID: "route1", + PeerGroups: []string{}, + Groups: []string{"group1"}, + AccessControlGroups: []string{}, }, }, NameServerGroups: map[string]*nbdns.NameServerGroup{ diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index cda3bc7482b..4c4ef6c3ca1 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -596,6 +596,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn response.NetworkMap.FirewallRules = firewallRules response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 + routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) + response.NetworkMap.RoutesFirewallRules = routesFirewallRules + response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 + return response } diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 2463f830e8b..fd0343e97bb 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -727,17 +727,39 @@ components: enum: ["all", "tcp", "udp", "icmp"] example: "tcp" ports: - description: Policy rule affected ports or it ranges list + description: Policy rule affected ports 
type: array items: type: string example: "80" + port_ranges: + description: Policy rule affected ports ranges list + type: array + items: + $ref: '#/components/schemas/RulePortRange' required: - name - enabled - bidirectional - protocol - action + + RulePortRange: + description: Policy rule affected ports range + type: object + properties: + start: + description: The starting port of the range + type: integer + example: 80 + end: + description: The ending port of the range + type: integer + example: 320 + required: + - start + - end + PolicyRuleUpdate: allOf: - $ref: '#/components/schemas/PolicyRuleMinimum' @@ -1106,6 +1128,12 @@ components: description: Indicate if the route should be kept after a domain doesn't resolve that IP anymore type: boolean example: true + access_control_groups: + description: Access control group identifier associated with route. + type: array + items: + type: string + example: "chacbco6lnnbn6cg5s91" required: - id - description diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index b219d38fdd2..570ec03c5bc 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -780,7 +780,10 @@ type PolicyRule struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ -816,7 +819,10 @@ type PolicyRuleMinimum struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ -852,7 +858,10 @@ type PolicyRuleUpdate struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ -935,6 +944,9 @@ type ProcessCheck struct { // Route defines model for Route. type Route struct { + // AccessControlGroups Access control group identifier associated with route. + AccessControlGroups *[]string `json:"access_control_groups,omitempty"` + // Description Route description Description string `json:"description"` @@ -977,6 +989,9 @@ type Route struct { // RouteRequest defines model for RouteRequest. type RouteRequest struct { + // AccessControlGroups Access control group identifier associated with route. + AccessControlGroups *[]string `json:"access_control_groups,omitempty"` + // Description Route description Description string `json:"description"` @@ -1011,6 +1026,15 @@ type RouteRequest struct { PeerGroups *[]string `json:"peer_groups,omitempty"` } +// RulePortRange Policy rule affected ports range +type RulePortRange struct { + // End The ending port of the range + End int `json:"end"` + + // Start The starting port of the range + Start int `json:"start"` +} + // SetupKey defines model for SetupKey. 
type SetupKey struct { // AutoGroups List of group IDs to auto-assign to peers registered with this key diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 225d7e1f30c..73f3803b5ed 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -172,6 +172,11 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } + if (rule.Ports != nil && len(*rule.Ports) != 0) && (rule.PortRanges != nil && len(*rule.PortRanges) != 0) { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either individual ports or port ranges, not both"), w) + return + } + if rule.Ports != nil && len(*rule.Ports) != 0 { for _, v := range *rule.Ports { if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 { @@ -182,10 +187,23 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID } } + if rule.PortRanges != nil && len(*rule.PortRanges) != 0 { + for _, portRange := range *rule.PortRanges { + if portRange.Start < 1 || portRange.End > 65535 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) + return + } + pr.PortRanges = append(pr.PortRanges, server.RulePortRange{ + Start: uint16(portRange.Start), + End: uint16(portRange.End), + }) + } + } + // validate policy object switch pr.Protocol { case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP: - if len(pr.Ports) != 0 { + if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) return } @@ -194,7 +212,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP: - if !pr.Bidirectional && len(pr.Ports) == 0 { + if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } @@ -320,6 +338,17 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic rule.Ports = &portsCopy } + if len(r.PortRanges) != 0 { + portRanges := make([]api.RulePortRange, 0, len(r.PortRanges)) + for _, portRange := range r.PortRanges { + portRanges = append(portRanges, api.RulePortRange{ + End: int(portRange.End), + Start: int(portRange.Start), + }) + } + rule.PortRanges = &portRanges + } + for _, gid := range r.Sources { _, ok := cache[gid] if ok { diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index 0932e64455e..ce4edee4f16 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -117,9 +117,14 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { peerGroupIds = *req.PeerGroups } + var accessControlGroupIds []string + if req.AccessControlGroups != nil { + accessControlGroupIds = *req.AccessControlGroups + } + newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, - req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, userID, req.KeepRoute, - ) + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute) + if err != nil { 
util.WriteError(r.Context(), err, w) return @@ -233,6 +238,10 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { newRoute.PeerGroups = *req.PeerGroups } + if req.AccessControlGroups != nil { + newRoute.AccessControlGroups = *req.AccessControlGroups + } + err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute) if err != nil { util.WriteError(r.Context(), err, w) @@ -326,6 +335,9 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) { if len(serverRoute.PeerGroups) > 0 { route.PeerGroups = &serverRoute.PeerGroups } + if len(serverRoute.AccessControlGroups) > 0 { + route.AccessControlGroups = &serverRoute.AccessControlGroups + } return route, nil } diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index 2c367cac399..83bd7004d1c 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -105,7 +105,7 @@ func initRoutesTestData() *RoutesHandler { } return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) }, - CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { + CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { if peerID == notFoundPeerID { return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } @@ -119,18 +119,19 @@ func initRoutesTestData() *RoutesHandler { } return &route.Route{ - ID: existingRouteID, - NetID: netID, - Peer: peerID, - PeerGroups: peerGroups, - Network: prefix, - Domains: domains, - NetworkType: networkType, - Description: description, - Masquerade: masquerade, - Enabled: enabled, - Groups: groups, - KeepRoute: keepRoute, + ID: existingRouteID, + NetID: netID, + Peer: peerID, + PeerGroups: peerGroups, + Network: prefix, + Domains: domains, + NetworkType: networkType, + Description: description, + Masquerade: masquerade, + Enabled: enabled, + Groups: groups, + KeepRoute: keepRoute, + AccessControlGroups: accessControlGroups, }, nil }, SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error { @@ -268,6 +269,27 @@ func TestRoutesHandlers(t *testing.T) { Groups: []string{existingGroupID}, }, }, + { + name: "POST OK With Access Control Groups", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedRoute: &api.Route{ + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: toPtr("192.168.0.0/16"), + Peer: &existingPeerID, + NetworkType: route.IPv4NetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + AccessControlGroups: &[]string{existingGroupID}, + }, + }, { name: "POST Non Linux Peer", 
requestType: http.MethodPost, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index df12ec1c437..b399be82288 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -58,7 +58,7 @@ type MockAccountManager struct { UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error @@ -367,7 +367,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, if am.DeleteRuleFunc != nil { return am.DeleteRuleFunc(ctx, accountID, ruleID, userID) } - return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented") + return status.Errorf(codes.Unimplemented, "method DeletePeerRule is not implemented") } // GetPolicy mock implementation of GetPolicy from server.AccountManager interface @@ -442,9 +442,9 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID } // CreateRoute mock implementation of CreateRoute from server.AccountManager interface -func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { +func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } diff --git a/management/server/network.go b/management/server/network.go index 0e7d753a73d..a5b188b4610 100644 --- 
a/management/server/network.go +++ b/management/server/network.go @@ -26,12 +26,13 @@ const ( ) type NetworkMap struct { - Peers []*nbpeer.Peer - Network *Network - Routes []*route.Route - DNSConfig nbdns.Config - OfflinePeers []*nbpeer.Peer - FirewallRules []*FirewallRule + Peers []*nbpeer.Peer + Network *Network + Routes []*route.Route + DNSConfig nbdns.Config + OfflinePeers []*nbpeer.Peer + FirewallRules []*FirewallRule + RoutesFirewallRules []*RouteFirewallRule } type Network struct { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index d329e04bc46..387adb91daf 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -646,7 +646,6 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { }) } - } func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) { @@ -991,9 +990,9 @@ func TestToSyncResponse(t *testing.T) { // assert network map Firewall assert.Equal(t, 1, len(response.NetworkMap.FirewallRules)) assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP) - assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction) - assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action) - assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol) + assert.Equal(t, proto.RuleDirection_IN, response.NetworkMap.FirewallRules[0].Direction) + assert.Equal(t, proto.RuleAction_ACCEPT, response.NetworkMap.FirewallRules[0].Action) + assert.Equal(t, proto.RuleProtocol_TCP, response.NetworkMap.FirewallRules[0].Protocol) assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port) // assert posture checks assert.Equal(t, 1, len(response.Checks)) diff --git a/management/server/policy.go b/management/server/policy.go index 5d07ba8f8a0..75647de449b 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -76,6 +76,12 @@ type PolicyUpdateOperation struct { Values []string } +// RulePortRange represents a range of ports for a firewall rule. +type RulePortRange struct { + Start uint16 + End uint16 +} + // PolicyRule is the metadata of the policy type PolicyRule struct { // ID of the policy rule @@ -110,6 +116,9 @@ type PolicyRule struct { // Ports or it ranges list Ports []string `gorm:"serializer:json"` + + // PortRanges a list of port ranges. 
+ PortRanges []RulePortRange `gorm:"serializer:json"` } // Copy returns a copy of a policy rule @@ -125,10 +134,12 @@ func (pm *PolicyRule) Copy() *PolicyRule { Bidirectional: pm.Bidirectional, Protocol: pm.Protocol, Ports: make([]string, len(pm.Ports)), + PortRanges: make([]RulePortRange, len(pm.PortRanges)), } copy(rule.Destinations, pm.Destinations) copy(rule.Sources, pm.Sources) copy(rule.Ports, pm.Ports) + copy(rule.PortRanges, pm.PortRanges) return rule } @@ -445,36 +456,17 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli return nil } -func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(update)) - for i := range update { - direction := proto.FirewallRule_IN - if update[i].Direction == firewallRuleDirectionOUT { - direction = proto.FirewallRule_OUT - } - action := proto.FirewallRule_ACCEPT - if update[i].Action == string(PolicyTrafficActionDrop) { - action = proto.FirewallRule_DROP - } - - protocol := proto.FirewallRule_UNKNOWN - switch PolicyRuleProtocolType(update[i].Protocol) { - case PolicyRuleProtocolALL: - protocol = proto.FirewallRule_ALL - case PolicyRuleProtocolTCP: - protocol = proto.FirewallRule_TCP - case PolicyRuleProtocolUDP: - protocol = proto.FirewallRule_UDP - case PolicyRuleProtocolICMP: - protocol = proto.FirewallRule_ICMP - } +func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] result[i] = &proto.FirewallRule{ - PeerIP: update[i].PeerIP, - Direction: direction, - Action: action, - Protocol: protocol, - Port: update[i].Port, + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, } } return result diff --git a/management/server/route.go b/management/server/route.go index 6c1c8b1b3c0..39ee6170c77 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -4,9 +4,15 @@ import ( "context" "fmt" "net/netip" + "slices" + "strconv" + "strings" "unicode/utf8" "github.com/rs/xid" + log "github.com/sirupsen/logrus" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" @@ -15,6 +21,30 @@ import ( "github.com/netbirdio/netbird/route" ) +// RouteFirewallRule a firewall rule applicable for a routed network. +type RouteFirewallRule struct { + // SourceRanges IP ranges of the routing peers. 
+ SourceRanges []string + + // Action of the traffic when the rule is applicable + Action string + + // Destination a network prefix for the routed traffic + Destination string + + // Protocol of the traffic + Protocol string + + // Port of the traffic + Port uint16 + + // PortRange represents the range of ports for a firewall rule + PortRange RulePortRange + + // isDynamic indicates whether the rule is for DNS routing + IsDynamic bool +} + // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) @@ -112,7 +142,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string { } // CreateRoute creates and saves a new route -func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { +func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -157,6 +187,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } } + if len(accessControlGroupIDs) > 0 { + err = validateGroups(accessControlGroupIDs, account.Groups) + if err != nil { + return nil, err + } + } + err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) if err != nil { return nil, err @@ -187,6 +224,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri newRoute.Enabled = enabled newRoute.Groups = groups newRoute.KeepRoute = keepRoute + newRoute.AccessControlGroups = accessControlGroupIDs if account.Routes == nil { account.Routes = make(map[route.ID]*route.Route) @@ -258,6 +296,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } } + if len(routeToSave.AccessControlGroups) > 0 { + err = validateGroups(routeToSave.AccessControlGroups, account.Groups) + if err != nil { + return err + } + } + err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) if err != nil { return err @@ -351,3 +396,248 @@ func getPlaceholderIP() netip.Prefix { // Using an IP from the documentation range to minimize impact in case older clients try to set a route return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } + +// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. +func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) + + enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) + for _, route := range enabledRoutes { + // If no access control groups are specified, accept all traffic. 
+ if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups) + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + } + + return routesFirewallRules +} + +func getDefaultPermit(route *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + if route.Network.Addr().Is6() { + sources = []string{"::/0"} + } + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: route.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + IsDynamic: route.IsDynamic(), + } + + rules = append(rules, &rule) + + // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally + if route.IsDynamic() { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups +// and returns a list of policies that have rules with destinations matching the specified groups. +func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + group, ok := account.Groups[groupID] + if !ok { + continue + } + + for _, policy := range account.Policies { + for _, rule := range policy.Rules { + exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { + return groupID == group.ID + }) + if exist { + routePolicies = append(routePolicies, policy) + continue + } + } + } + } + + return routePolicies +} + +// generateRouteFirewallRules generates a list of firewall rules for a given route. +func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { + rulesExists := make(map[string]struct{}) + rules := make([]*RouteFirewallRule, 0) + + sourceRanges := make([]string, 0, len(groupPeers)) + for _, peer := range groupPeers { + if peer == nil { + continue + } + sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) + } + + baseRule := RouteFirewallRule{ + SourceRanges: sourceRanges, + Action: string(rule.Action), + Destination: route.Network.String(), + Protocol: string(rule.Protocol), + IsDynamic: route.IsDynamic(), + } + + // generate rule for port range + if len(rule.Ports) == 0 { + rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) + } else { + rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) + + } + + // TODO: generate IPv6 rules for dynamic routes + + return rules +} + +// generateRuleIDBase generates the base rule ID for checking duplicates. 
+func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { + return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action +} + +// generateRulesForPeer generates rules for a given peer based on ports and port ranges. +func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + + ruleIDBase := generateRuleIDBase(rule, baseRule) + if len(rule.Ports) == 0 { + if len(rule.PortRanges) == 0 { + if _, ok := rulesExists[ruleIDBase]; !ok { + rulesExists[ruleIDBase] = struct{}{} + rules = append(rules, &baseRule) + } + } else { + for _, portRange := range rule.PortRanges { + ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End) + if _, ok := rulesExists[ruleID]; !ok { + rulesExists[ruleID] = struct{}{} + pr := baseRule + pr.PortRange = portRange + rules = append(rules, &pr) + } + } + } + return rules + } + + return rules +} + +// generateRulesWithPorts generates rules when specific ports are provided. +func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + ruleIDBase := generateRuleIDBase(rule, baseRule) + + for _, port := range rule.Ports { + ruleID := ruleIDBase + port + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + pr := baseRule + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID) + continue + } + + pr.Port = uint16(p) + rules = append(rules, &pr) + } + + return rules +} + +func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule { + result := make([]*proto.RouteFirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + result[i] = &proto.RouteFirewallRule{ + SourceRanges: rule.SourceRanges, + Action: getProtoAction(rule.Action), + Destination: rule.Destination, + Protocol: getProtoProtocol(rule.Protocol), + PortInfo: getProtoPortInfo(rule), + IsDynamic: rule.IsDynamic, + } + } + + return result +} + +// getProtoDirection converts the direction to proto.RuleDirection. +func getProtoDirection(direction int) proto.RuleDirection { + if direction == firewallRuleDirectionOUT { + return proto.RuleDirection_OUT + } + return proto.RuleDirection_IN +} + +// getProtoAction converts the action to proto.RuleAction. +func getProtoAction(action string) proto.RuleAction { + if action == string(PolicyTrafficActionDrop) { + return proto.RuleAction_DROP + } + return proto.RuleAction_ACCEPT +} + +// getProtoProtocol converts the protocol to proto.RuleProtocol. +func getProtoProtocol(protocol string) proto.RuleProtocol { + switch PolicyRuleProtocolType(protocol) { + case PolicyRuleProtocolALL: + return proto.RuleProtocol_ALL + case PolicyRuleProtocolTCP: + return proto.RuleProtocol_TCP + case PolicyRuleProtocolUDP: + return proto.RuleProtocol_UDP + case PolicyRuleProtocolICMP: + return proto.RuleProtocol_ICMP + default: + return proto.RuleProtocol_UNKNOWN + } +} + +// getProtoPortInfo converts the port info to proto.PortInfo. 
+func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { + var portInfo proto.PortInfo + if rule.Port != 0 { + portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} + } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { + portInfo.PortSelection = &proto.PortInfo_Range_{ + Range: &proto.PortInfo_Range{ + Start: uint32(portRange.Start), + End: uint32(portRange.End), + }, + } + } + return &portInfo +} diff --git a/management/server/route_test.go b/management/server/route_test.go index 4533c6b7e5c..b556816be7a 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2,6 +2,8 @@ package server import ( "context" + "fmt" + "net" "net/netip" "testing" @@ -44,18 +46,19 @@ var existingDomains = domain.List{"example.com"} func TestCreateRoute(t *testing.T) { type input struct { - network netip.Prefix - domains domain.List - keepRoute bool - networkType route.NetworkType - netID route.NetID - peerKey string - peerGroupIDs []string - description string - masquerade bool - metric int - enabled bool - groups []string + network netip.Prefix + domains domain.List + keepRoute bool + networkType route.NetworkType + netID route.NetID + peerKey string + peerGroupIDs []string + description string + masquerade bool + metric int + enabled bool + groups []string + accessControlGroups []string } testCases := []struct { @@ -69,100 +72,107 @@ func TestCreateRoute(t *testing.T) { { name: "Happy Path Network", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerKey: peer1ID, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerKey: peer1ID, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup1}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetworkType: route.IPv4Network, - NetID: "happy", - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetworkType: route.IPv4Network, + NetID: "happy", + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + AccessControlGroups: []string{routeGroup1}, }, }, { name: "Happy Path Domains", inputArgs: input{ - domains: domain.List{"domain1", "domain2"}, - keepRoute: true, - networkType: route.DomainNetwork, - netID: "happy", - peerKey: peer1ID, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + domains: domain.List{"domain1", "domain2"}, + keepRoute: true, + networkType: route.DomainNetwork, + netID: "happy", + peerKey: peer1ID, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup1}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.0.2.0/32"), - Domains: domain.List{"domain1", "domain2"}, - NetworkType: route.DomainNetwork, - NetID: "happy", - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: 
[]string{routeGroup1}, - KeepRoute: true, + Network: netip.MustParsePrefix("192.0.2.0/32"), + Domains: domain.List{"domain1", "domain2"}, + NetworkType: route.DomainNetwork, + NetID: "happy", + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + KeepRoute: true, + AccessControlGroups: []string{routeGroup1}, }, }, { name: "Happy Path Peer Groups", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerGroupIDs: []string{routeGroupHA1, routeGroupHA2}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1, routeGroup2}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerGroupIDs: []string{routeGroupHA1, routeGroupHA2}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1, routeGroup2}, + accessControlGroups: []string{routeGroup1, routeGroup2}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetworkType: route.IPv4Network, - NetID: "happy", - PeerGroups: []string{routeGroupHA1, routeGroupHA2}, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1, routeGroup2}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetworkType: route.IPv4Network, + NetID: "happy", + PeerGroups: []string{routeGroupHA1, routeGroupHA2}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1, routeGroup2}, + AccessControlGroups: []string{routeGroup1, routeGroup2}, }, }, { name: "Both network and domains provided should fail", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - domains: domain.List{"domain1", "domain2"}, - netID: "happy", - peerKey: peer1ID, - peerGroupIDs: []string{routeGroupHA1}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + domains: domain.List{"domain1", "domain2"}, + netID: "happy", + peerKey: peer1ID, + peerGroupIDs: []string{routeGroupHA1}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup2}, }, errFunc: require.Error, shouldCreate: false, @@ -170,16 +180,17 @@ func TestCreateRoute(t *testing.T) { { name: "Both peer and peer_groups Provided Should Fail", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerKey: peer1ID, - peerGroupIDs: []string{routeGroupHA1}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerKey: peer1ID, + peerGroupIDs: []string{routeGroupHA1}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup2}, }, errFunc: require.Error, shouldCreate: false, @@ -423,13 +434,13 @@ func TestCreateRoute(t *testing.T) { if testCase.createInitRoute { groupAll, errInit := account.GetGroupAll() require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, 
"", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) require.NoError(t, errInit) } - outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) + outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) testCase.errFunc(t, err) @@ -1037,15 +1048,16 @@ func TestDeleteRoute(t *testing.T) { func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { baseRoute := &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superNet", - NetworkType: route.IPv4Network, - PeerGroups: []string{routeGroupHA1, routeGroupHA2}, - Description: "ha route", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1, routeGroup2}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{routeGroupHA1, routeGroupHA2}, + Description: "ha route", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1, routeGroup2}, + AccessControlGroups: []string{routeGroup1}, } am, err := createRouterManager(t) @@ -1062,7 +1074,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute) + newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) require.NoError(t, err) require.Equal(t, newRoute.Enabled, true) @@ -1127,16 +1139,17 @@ func TestGetNetworkMap_RouteSync(t *testing.T) 
{ // no routes for peer in different groups // no routes when route is deleted baseRoute := &route.Route{ - ID: "testingRoute", - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superNet", - NetworkType: route.IPv4Network, - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, + ID: "testingRoute", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + AccessControlGroups: []string{routeGroup1}, } am, err := createRouterManager(t) @@ -1153,7 +1166,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute) + createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) require.NoError(t, err) noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1467,3 +1480,300 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return am.Store.GetAccount(context.Background(), account.Id) } + +func TestAccount_getPeersRoutesFirewall(t *testing.T) { + var ( + peerBIp = "100.65.80.39" + peerCIp = "100.65.254.139" + peerHIp = "100.65.29.55" + ) + + account := &Account{ + Peers: map[string]*nbpeer.Peer{ + "peerA": { + ID: "peerA", + IP: net.ParseIP("100.65.14.88"), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerB": { + ID: "peerB", + IP: net.ParseIP(peerBIp), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{}, + }, + "peerC": { + ID: "peerC", + IP: net.ParseIP(peerCIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerD": { + ID: "peerD", + IP: net.ParseIP("100.65.62.5"), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerE": { + ID: "peerE", + IP: net.ParseIP("100.65.32.206"), + Key: peer1Key, + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerF": { + ID: "peerF", + IP: net.ParseIP("100.65.250.202"), + Status: &nbpeer.PeerStatus{}, + }, + "peerG": { + ID: "peerG", + IP: net.ParseIP("100.65.13.186"), + Status: &nbpeer.PeerStatus{}, + }, + "peerH": { + ID: "peerH", + IP: net.ParseIP(peerHIp), + Status: &nbpeer.PeerStatus{}, + }, + }, + Groups: map[string]*nbgroup.Group{ + "routingPeer1": { + ID: "routingPeer1", + Name: "RoutingPeer1", + Peers: []string{ + "peerA", + }, + }, + "routingPeer2": { + ID: "routingPeer2", + Name: "RoutingPeer2", + Peers: []string{ + "peerD", + }, + }, + "route1": { + ID: "route1", + Name: "Route1", + Peers: []string{}, + }, + "route2": { + ID: "route2", + Name: "Route2", + Peers: []string{}, + }, + "finance": { + ID: "finance", + Name: "Finance", + Peers: []string{ + "peerF", + "peerG", + }, + }, + "dev": { + ID: "dev", + Name: "Dev", + Peers: []string{ + "peerC", + "peerH", + "peerB", + }, 
+ }, + "contractors": { + ID: "contractors", + Name: "Contractors", + Peers: []string{}, + }, + }, + Routes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "route1", + NetworkType: route.IPv4Network, + PeerGroups: []string{"routingPeer1", "routingPeer2"}, + Description: "Route1 ha route", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"dev"}, + AccessControlGroups: []string{"route1"}, + }, + "route2": { + ID: "route2", + Network: existingNetwork, + NetID: "route2", + NetworkType: route.IPv4Network, + Peer: "peerE", + Description: "Allow", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"finance"}, + AccessControlGroups: []string{"route2"}, + }, + "route3": { + ID: "route3", + Network: netip.MustParsePrefix("192.0.2.0/32"), + Domains: domain.List{"example.com"}, + NetID: "route3", + NetworkType: route.DomainNetwork, + Peer: "peerE", + Description: "Allow all traffic to routed DNS network", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"contractors"}, + AccessControlGroups: []string{}, + }, + }, + Policies: []*Policy{ + { + ID: "RuleRoute1", + Name: "Route1", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute1", + Name: "ruleRoute1", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + Ports: []string{"80", "320"}, + Sources: []string{ + "dev", + }, + Destinations: []string{ + "route1", + }, + }, + }, + }, + { + ID: "RuleRoute2", + Name: "Route2", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute2", + Name: "ruleRoute2", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + PortRanges: []RulePortRange{ + { + Start: 80, + End: 350, + }, { + Start: 80, + End: 350, + }, + }, + Sources: []string{ + "finance", + }, + Destinations: []string{ + "route2", + }, + }, + }, + }, + }, + } + + validatedPeers := make(map[string]struct{}) + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + + t.Run("check applied policies for the route", func(t *testing.T) { + route1 := account.Routes["route1"] + policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) + assert.Len(t, policies, 1) + + route2 := account.Routes["route2"] + policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) + assert.Len(t, policies, 1) + + route3 := account.Routes["route3"] + policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) + assert.Len(t, policies, 0) + }) + + t.Run("check peer routes firewall rules", func(t *testing.T) { + routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) + assert.Len(t, routesFirewallRules, 2) + + expectedRoutesFirewallRules := []*RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerCIp), + fmt.Sprintf(AllowedIPsFormat, peerHIp), + fmt.Sprintf(AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerCIp), + fmt.Sprintf(AllowedIPsFormat, peerHIp), + fmt.Sprintf(AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 320, + }, + } + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + + //peerD is also the routing peer for route1, should contain same 
routes firewall rules as peerA + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) + assert.Len(t, routesFirewallRules, 2) + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + + // peerE is a single routing peer for route 2 and route 3 + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) + assert.Len(t, routesFirewallRules, 3) + + expectedRoutesFirewallRules = []*RouteFirewallRule{ + { + SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, + Action: "accept", + Destination: existingNetwork.String(), + Protocol: "tcp", + PortRange: RulePortRange{Start: 80, End: 350}, + }, + { + SourceRanges: []string{"0.0.0.0/0"}, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "all", + IsDynamic: true, + }, + { + SourceRanges: []string{"::/0"}, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "all", + IsDynamic: true, + }, + } + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + + // peerC is part of route1 distribution groups but should not receive the routes firewall rules + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) + assert.Len(t, routesFirewallRules, 0) + }) + +} diff --git a/route/route.go b/route/route.go index eb6c36bd8bc..e23801e6e9e 100644 --- a/route/route.go +++ b/route/route.go @@ -100,6 +100,7 @@ type Route struct { Metric int Enabled bool Groups []string `gorm:"serializer:json"` + AccessControlGroups []string `gorm:"serializer:json"` } // EventMeta returns activity event meta related to the route @@ -123,6 +124,7 @@ func (r *Route) Copy() *Route { Masquerade: r.Masquerade, Enabled: r.Enabled, Groups: slices.Clone(r.Groups), + AccessControlGroups: slices.Clone(r.AccessControlGroups), } return route } @@ -147,7 +149,8 @@ func (r *Route) IsEqual(other *Route) bool { other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && slices.Equal(r.Groups, other.Groups) && - slices.Equal(r.PeerGroups, other.PeerGroups) + slices.Equal(r.PeerGroups, other.PeerGroups)&& + slices.Equal(r.AccessControlGroups, other.AccessControlGroups) } // IsDynamic returns if the route is dynamic, i.e. 
has domains From b7b08281336676f356c8e1032b1907c617b6439c Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 2 Oct 2024 15:14:09 +0200 Subject: [PATCH 19/81] [client] Adjust relay worker log level and message (#2683) --- client/internal/peer/worker_relay.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 6bb385d3e20..c02fccebc47 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -74,7 +74,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) if err != nil { if errors.Is(err, relayClient.ErrConnAlreadyExists) { - w.log.Infof("do not need to reopen relay connection") + w.log.Debugf("handled offer by reusing existing relay connection") return } w.log.Errorf("failed to open connection via Relay: %s", err) From 7e5d3bdfe2306f69ef5daab3c742c4d206c69406 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 2 Oct 2024 15:33:38 +0200 Subject: [PATCH 20/81] [signal] Move dummy signal message handling into dispatcher (#2686) --- go.mod | 2 +- go.sum | 4 ++-- signal/server/signal.go | 5 ----- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index c29ba076347..e7137ce5bf5 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd - github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 1f6cbb785be..4563dc9335f 100644 --- a/go.sum +++ b/go.sum @@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811- github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757 h1:6XniCzDt+1jvXWMUY4EDH0Hi5RXbUOYB0A8XEQqSlZk= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20240929132730-cbef5d331757/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= diff --git a/signal/server/signal.go b/signal/server/signal.go index 386ce72389f..63cc43bd7ef 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -71,11 
+71,6 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.Debugf("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) - if msg.RemoteKey == "dummy" { - // Test message send during netbird status - return &proto.EncryptedMessage{}, nil - } - if _, found := s.registry.Get(msg.RemoteKey); found { s.forwardMessageToPeer(ctx, msg) return &proto.EncryptedMessage{}, nil From fd67892cb4fa0c4e4c23c0511796cc7ce9fe296c Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 2 Oct 2024 18:24:22 +0200 Subject: [PATCH 21/81] [client] Refactor/iface pkg (#2646) Refactor the flat code structure --- .github/workflows/golang-test-freebsd.yml | 2 +- .github/workflows/golang-test-linux.yml | 2 +- client/android/client.go | 6 +-- client/cmd/login_test.go | 2 +- client/cmd/root_test.go | 2 +- client/cmd/up.go | 2 +- client/firewall/iface.go | 6 +-- client/firewall/iptables/manager_linux.go | 2 +- .../firewall/iptables/manager_linux_test.go | 2 +- client/firewall/nftables/acl_linux.go | 2 +- .../firewall/nftables/manager_linux_test.go | 2 +- client/firewall/uspfilter/uspfilter.go | 5 +- client/firewall/uspfilter/uspfilter_test.go | 21 ++++---- {iface => client/iface}/bind/bind.go | 0 {iface => client/iface}/bind/udp_mux.go | 0 .../iface}/bind/udp_mux_universal.go | 0 .../iface}/bind/udp_muxed_conn.go | 0 client/iface/configurer/err.go | 5 ++ .../iface/configurer/kernel_unix.go | 27 +++++----- {iface => client/iface/configurer}/name.go | 2 +- .../iface/configurer}/name_darwin.go | 2 +- {iface => client/iface/configurer}/uapi.go | 2 +- .../iface/configurer}/uapi_windows.go | 2 +- .../iface/configurer/usp.go | 24 ++++----- .../iface/configurer/usp_test.go | 2 +- client/iface/configurer/wgstats.go | 9 ++++ client/iface/device.go | 18 +++++++ .../iface/device/adapter.go | 2 +- {iface => client/iface/device}/address.go | 8 +-- .../iface/device/args.go | 2 +- .../iface/device/device_android.go | 54 +++++++++---------- .../iface/device/device_darwin.go | 49 ++++++++--------- .../iface/device/device_filter.go | 19 +++---- .../iface/device/device_filter_test.go | 13 ++--- .../iface/device/device_ios.go | 49 ++++++++--------- .../iface/device/device_kernel_unix.go | 31 +++++------ .../iface/device/device_netstack.go | 49 ++++++++--------- .../iface/device/device_usp_unix.go | 52 +++++++++--------- .../iface/device/device_windows.go | 47 ++++++++-------- client/iface/device/interface.go | 20 +++++++ .../iface/device/kernel_module.go | 2 +- .../iface/device/kernel_module_freebsd.go | 6 +-- .../iface/device/kernel_module_linux.go | 6 +-- .../iface/device/kernel_module_linux_test.go | 8 +-- .../iface/device/wg_link_freebsd.go | 5 +- .../iface/device/wg_link_linux.go | 2 +- {iface => client/iface/device}/wg_log.go | 2 +- client/iface/device/windows_guid.go | 4 ++ client/iface/device_android.go | 16 ++++++ {iface => client/iface}/freebsd/errors.go | 0 {iface => client/iface}/freebsd/iface.go | 0 .../iface}/freebsd/iface_internal_test.go | 0 {iface => client/iface}/freebsd/link.go | 0 {iface => client/iface}/iface.go | 53 +++++++++--------- {iface => client/iface}/iface_android.go | 9 ++-- {iface => client/iface}/iface_create.go | 0 {iface => client/iface}/iface_darwin.go | 13 ++--- {iface => client/iface}/iface_destroy_bsd.go | 0 .../iface}/iface_destroy_linux.go | 0 .../iface}/iface_destroy_mobile.go | 0 .../iface}/iface_destroy_windows.go | 0 {iface => 
client/iface}/iface_ios.go | 9 ++-- {iface => client/iface}/iface_moc.go | 24 +++++---- {iface => client/iface}/iface_test.go | 6 ++- {iface => client/iface}/iface_unix.go | 19 +++---- {iface => client/iface}/iface_windows.go | 15 +++--- {iface => client/iface}/iwginterface.go | 14 ++--- .../iface}/iwginterface_windows.go | 14 ++--- {iface => client/iface}/mocks/README.md | 0 {iface => client/iface}/mocks/filter.go | 2 +- .../iface}/mocks/iface/mocks/filter.go | 2 +- {iface => client/iface}/mocks/tun.go | 0 {iface => client/iface}/netstack/dialer.go | 0 {iface => client/iface}/netstack/env.go | 0 {iface => client/iface}/netstack/proxy.go | 0 {iface => client/iface}/netstack/tun.go | 0 client/internal/acl/manager_test.go | 2 +- client/internal/acl/mocks/iface_mapper.go | 5 +- client/internal/config.go | 2 +- client/internal/connect.go | 7 +-- client/internal/dns/response_writer_test.go | 2 +- client/internal/dns/server_test.go | 18 ++++--- client/internal/dns/wgiface.go | 10 ++-- client/internal/dns/wgiface_windows.go | 12 +++-- client/internal/engine.go | 13 ++--- client/internal/engine_test.go | 7 +-- client/internal/mobile_dependency.go | 4 +- client/internal/peer/conn.go | 5 +- client/internal/peer/conn_test.go | 2 +- client/internal/peer/status.go | 6 +-- client/internal/peer/worker_ice.go | 4 +- client/internal/routemanager/client.go | 2 +- client/internal/routemanager/dynamic/route.go | 2 +- client/internal/routemanager/manager.go | 5 +- client/internal/routemanager/manager_test.go | 2 +- client/internal/routemanager/mock.go | 2 +- .../internal/routemanager/server_android.go | 2 +- .../routemanager/server_nonandroid.go | 2 +- .../routemanager/sysctl/sysctl_linux.go | 2 +- .../routemanager/systemops/systemops.go | 2 +- .../systemops/systemops_generic.go | 2 +- .../systemops/systemops_generic_test.go | 2 +- iface/tun.go | 21 -------- iface/wg_configurer.go | 21 -------- util/net/net.go | 2 +- 105 files changed, 505 insertions(+), 438 deletions(-) rename {iface => client/iface}/bind/bind.go (100%) rename {iface => client/iface}/bind/udp_mux.go (100%) rename {iface => client/iface}/bind/udp_mux_universal.go (100%) rename {iface => client/iface}/bind/udp_muxed_conn.go (100%) create mode 100644 client/iface/configurer/err.go rename iface/wg_configurer_kernel_unix.go => client/iface/configurer/kernel_unix.go (83%) rename {iface => client/iface/configurer}/name.go (87%) rename {iface => client/iface/configurer}/name_darwin.go (86%) rename {iface => client/iface/configurer}/uapi.go (96%) rename {iface => client/iface/configurer}/uapi_windows.go (88%) rename iface/wg_configurer_usp.go => client/iface/configurer/usp.go (93%) rename iface/wg_configurer_usp_test.go => client/iface/configurer/usp_test.go (99%) create mode 100644 client/iface/configurer/wgstats.go create mode 100644 client/iface/device.go rename iface/tun_adapter.go => client/iface/device/adapter.go (94%) rename {iface => client/iface/device}/address.go (69%) rename iface/tun_args.go => client/iface/device/args.go (88%) rename iface/tun_android.go => client/iface/device/device_android.go (61%) rename iface/tun_darwin.go => client/iface/device/device_darwin.go (69%) rename iface/device_wrapper.go => client/iface/device/device_filter.go (81%) rename iface/device_wrapper_test.go => client/iface/device/device_filter_test.go (95%) rename iface/tun_ios.go => client/iface/device/device_ios.go (63%) rename iface/tun_kernel_unix.go => client/iface/device/device_kernel_unix.go (75%) rename iface/tun_netstack.go => 
client/iface/device/device_netstack.go (56%) rename iface/tun_usp_unix.go => client/iface/device/device_usp_unix.go (63%) rename iface/tun_windows.go => client/iface/device/device_windows.go (75%) create mode 100644 client/iface/device/interface.go rename iface/module.go => client/iface/device/kernel_module.go (92%) rename iface/module_freebsd.go => client/iface/device/kernel_module_freebsd.go (84%) rename iface/module_linux.go => client/iface/device/kernel_module_linux.go (98%) rename iface/module_linux_test.go => client/iface/device/kernel_module_linux_test.go (98%) rename iface/tun_link_freebsd.go => client/iface/device/wg_link_freebsd.go (95%) rename iface/tun_link_linux.go => client/iface/device/wg_link_linux.go (99%) rename {iface => client/iface/device}/wg_log.go (93%) create mode 100644 client/iface/device/windows_guid.go create mode 100644 client/iface/device_android.go rename {iface => client/iface}/freebsd/errors.go (100%) rename {iface => client/iface}/freebsd/iface.go (100%) rename {iface => client/iface}/freebsd/iface_internal_test.go (100%) rename {iface => client/iface}/freebsd/link.go (100%) rename {iface => client/iface}/iface.go (79%) rename {iface => client/iface}/iface_android.go (67%) rename {iface => client/iface}/iface_create.go (100%) rename {iface => client/iface}/iface_darwin.go (68%) rename {iface => client/iface}/iface_destroy_bsd.go (100%) rename {iface => client/iface}/iface_destroy_linux.go (100%) rename {iface => client/iface}/iface_destroy_mobile.go (100%) rename {iface => client/iface}/iface_destroy_windows.go (100%) rename {iface => client/iface}/iface_ios.go (59%) rename {iface => client/iface}/iface_moc.go (76%) rename {iface => client/iface}/iface_test.go (98%) rename {iface => client/iface}/iface_unix.go (53%) rename {iface => client/iface}/iface_windows.go (52%) rename {iface => client/iface}/iwginterface.go (65%) rename {iface => client/iface}/iwginterface_windows.go (65%) rename {iface => client/iface}/mocks/README.md (100%) rename {iface => client/iface}/mocks/filter.go (97%) rename {iface => client/iface}/mocks/iface/mocks/filter.go (97%) rename {iface => client/iface}/mocks/tun.go (100%) rename {iface => client/iface}/netstack/dialer.go (100%) rename {iface => client/iface}/netstack/env.go (100%) rename {iface => client/iface}/netstack/proxy.go (100%) rename {iface => client/iface}/netstack/tun.go (100%) delete mode 100644 iface/tun.go delete mode 100644 iface/wg_configurer.go diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index 4f13ee30e63..a2d743715fa 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -38,7 +38,7 @@ jobs: time go test -timeout 1m -failfast ./dns/... time go test -timeout 1m -failfast ./encryption/... time go test -timeout 1m -failfast ./formatter/... - time go test -timeout 1m -failfast ./iface/... + time go test -timeout 1m -failfast ./client/iface/... time go test -timeout 1m -failfast ./route/... time go test -timeout 1m -failfast ./sharedsock/... time go test -timeout 1m -failfast ./signal/... 
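For orientation before the remaining hunks of this refactor: the patch relocates the former top-level iface package under client/iface, splits it into device and configurer sub-packages, and exports constructors and types that used to be unexported. Below is a minimal sketch of what a caller of the relocated packages could look like. It assumes the exported names this patch introduces further down in the diff (device.ParseWGAddress, configurer.NewKernelConfigurer, configurer.WgInterfaceDefault); the address literal is made up for illustration, and the snippet only builds on Linux or FreeBSD, where the kernel configurer is compiled in. It is an illustrative sketch, not part of the patch itself.

package main

import (
	"log"

	"github.com/netbirdio/netbird/client/iface/configurer"
	"github.com/netbirdio/netbird/client/iface/device"
)

func main() {
	// ParseWGAddress is the exported replacement for the former private parseWGAddress helper.
	addr, err := device.ParseWGAddress("100.65.0.5/16") // example CIDR, chosen for illustration only
	if err != nil {
		log.Fatalf("parse WireGuard address: %v", err)
	}
	log.Printf("interface address %s in network %s", addr.IP, addr.Network)

	// NewKernelConfigurer replaces the unexported newWGConfigurer constructor;
	// on a real system cfg.ConfigureInterface(privateKey, port) would be called next
	// to apply the private key and listen port to the kernel WireGuard device.
	cfg := configurer.NewKernelConfigurer(configurer.WgInterfaceDefault)
	_ = cfg
}

Functionally nothing changes in this commit: files move under client/iface and previously unexported types gain exported names so that packages such as client/firewall and client/internal/dns can import them from the new path, as the remaining hunks show.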
diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 2d5cf2856ce..524f35f6f47 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -80,7 +80,7 @@ jobs: run: git --no-pager diff --exit-code - name: Generate Iface Test bin - run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/ + run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/ - name: Generate Shared Sock Test bin run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock diff --git a/client/android/client.go b/client/android/client.go index d937e132e35..229bcd97409 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" @@ -15,7 +16,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util/net" ) @@ -26,7 +26,7 @@ type ConnectionListener interface { // TunAdapter export internal TunAdapter for mobile type TunAdapter interface { - iface.TunAdapter + device.TunAdapter } // IFaceDiscover export internal IFaceDiscover for mobile @@ -51,7 +51,7 @@ func init() { // Client struct manage the life circle of background service type Client struct { cfgFile string - tunAdapter iface.TunAdapter + tunAdapter device.TunAdapter iFaceDiscover IFaceDiscover recorder *peer.Status ctxCancel context.CancelFunc diff --git a/client/cmd/login_test.go b/client/cmd/login_test.go index 6bb7eff4f8a..fa20435ea6e 100644 --- a/client/cmd/login_test.go +++ b/client/cmd/login_test.go @@ -5,8 +5,8 @@ import ( "strings" "testing" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util" ) diff --git a/client/cmd/root_test.go b/client/cmd/root_test.go index f2805cf35fb..4cbbe8783ed 100644 --- a/client/cmd/root_test.go +++ b/client/cmd/root_test.go @@ -7,7 +7,7 @@ import ( "github.com/spf13/cobra" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) func TestInitCommands(t *testing.T) { diff --git a/client/cmd/up.go b/client/cmd/up.go index b447f714104..05ecce9e0fd 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -15,11 +15,11 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/durationpb" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util" ) diff --git a/client/firewall/iface.go b/client/firewall/iface.go index d0b5209c040..f349f9210a6 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -1,13 +1,13 @@ package firewall import ( - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface/device" ) // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { Name() string - Address() iface.WGAddress + Address() device.WGAddress IsUserspaceBind() bool - SetFilter(iface.PacketFilter) error 
+ SetFilter(device.PacketFilter) error } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index fae41d9c5a9..6fefd58e67e 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -11,7 +11,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) // Manager of iptables firewall diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 0072aa15961..498d8f58b09 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) var ifaceMock = &iFaceMock{ diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 85cba9e1cc2..eaf7fb6a023 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -16,7 +16,7 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) const ( diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 7f78a9a2e02..904050a517f 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -15,7 +15,7 @@ import ( "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) var ifaceMock = &iFaceMock{ diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 681058ea949..0e3ee97991f 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -12,7 +12,8 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) const layerTypeAll = 0 @@ -23,7 +24,7 @@ var ( // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { - SetFilter(iface.PacketFilter) error + SetFilter(device.PacketFilter) error Address() iface.WGAddress } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index dd7366fe93d..c188deea460 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -11,15 +11,16 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) type IFaceMock struct { - SetFilterFunc func(iface.PacketFilter) error + SetFilterFunc func(device.PacketFilter) error AddressFunc func() iface.WGAddress } -func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { +func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { if i.SetFilterFunc == nil { return fmt.Errorf("not implemented") } @@ -35,7 +36,7 @@ func (i *IFaceMock) 
Address() iface.WGAddress { func TestManagerCreate(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -52,7 +53,7 @@ func TestManagerCreate(t *testing.T) { func TestManagerAddPeerFiltering(t *testing.T) { isSetFilterCalled := false ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { + SetFilterFunc: func(device.PacketFilter) error { isSetFilterCalled = true return nil }, @@ -90,7 +91,7 @@ func TestManagerAddPeerFiltering(t *testing.T) { func TestManagerDeleteRule(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -236,7 +237,7 @@ func TestAddUDPPacketHook(t *testing.T) { func TestManagerReset(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -271,7 +272,7 @@ func TestManagerReset(t *testing.T) { func TestNotMatchByIP(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -339,7 +340,7 @@ func TestNotMatchByIP(t *testing.T) { func TestRemovePacketHook(t *testing.T) { // creating mock iface iface := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } // creating manager instance @@ -388,7 +389,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } manager, err := Create(ifaceMock) require.NoError(t, err) diff --git a/iface/bind/bind.go b/client/iface/bind/bind.go similarity index 100% rename from iface/bind/bind.go rename to client/iface/bind/bind.go diff --git a/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go similarity index 100% rename from iface/bind/udp_mux.go rename to client/iface/bind/udp_mux.go diff --git a/iface/bind/udp_mux_universal.go b/client/iface/bind/udp_mux_universal.go similarity index 100% rename from iface/bind/udp_mux_universal.go rename to client/iface/bind/udp_mux_universal.go diff --git a/iface/bind/udp_muxed_conn.go b/client/iface/bind/udp_muxed_conn.go similarity index 100% rename from iface/bind/udp_muxed_conn.go rename to client/iface/bind/udp_muxed_conn.go diff --git a/client/iface/configurer/err.go b/client/iface/configurer/err.go new file mode 100644 index 00000000000..a64bba2dd55 --- /dev/null +++ b/client/iface/configurer/err.go @@ -0,0 +1,5 @@ +package configurer + +import "errors" + +var ErrPeerNotFound = errors.New("peer not found") diff --git a/iface/wg_configurer_kernel_unix.go b/client/iface/configurer/kernel_unix.go similarity index 83% rename from iface/wg_configurer_kernel_unix.go rename to client/iface/configurer/kernel_unix.go index 8b47082da83..7c1c416697c 100644 --- a/iface/wg_configurer_kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package configurer import ( "fmt" @@ -12,18 +12,17 @@ 
import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type wgKernelConfigurer struct { +type KernelConfigurer struct { deviceName string } -func newWGConfigurer(deviceName string) wgConfigurer { - wgc := &wgKernelConfigurer{ +func NewKernelConfigurer(deviceName string) *KernelConfigurer { + return &KernelConfigurer{ deviceName: deviceName, } - return wgc } -func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) error { +func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error { log.Debugf("adding Wireguard private key") key, err := wgtypes.ParseKey(privateKey) if err != nil { @@ -44,7 +43,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err return nil } -func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { // parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { @@ -75,7 +74,7 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA return nil } -func (c *wgKernelConfigurer) removePeer(peerKey string) error { +func (c *KernelConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -96,7 +95,7 @@ func (c *wgKernelConfigurer) removePeer(peerKey string) error { return nil } -func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) error { +func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return err @@ -123,7 +122,7 @@ func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) erro return nil } -func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { +func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return fmt.Errorf("parse allowed IP: %w", err) @@ -165,7 +164,7 @@ func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) e return nil } -func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { +func (c *KernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { wg, err := wgctrl.New() if err != nil { return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err) @@ -189,7 +188,7 @@ func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer return wgtypes.Peer{}, ErrPeerNotFound } -func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { +func (c *KernelConfigurer) configure(config wgtypes.Config) error { wg, err := wgctrl.New() if err != nil { return err @@ -205,10 +204,10 @@ func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { return wg.ConfigureDevice(c.deviceName, config) } -func (c *wgKernelConfigurer) close() { +func (c *KernelConfigurer) Close() { } -func (c *wgKernelConfigurer) getStats(peerKey string) (WGStats, error) { +func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) { peer, err := c.getPeer(c.deviceName, peerKey) if err != nil { return WGStats{}, fmt.Errorf("get wireguard stats: %w", err) diff --git a/iface/name.go b/client/iface/configurer/name.go similarity index 87% rename from iface/name.go rename to client/iface/configurer/name.go index 
706cb65ad3c..e2133d0ead2 100644 --- a/iface/name.go +++ b/client/iface/configurer/name.go @@ -1,6 +1,6 @@ //go:build linux || windows || freebsd -package iface +package configurer // WgInterfaceDefault is a default interface name of Wiretrustee const WgInterfaceDefault = "wt0" diff --git a/iface/name_darwin.go b/client/iface/configurer/name_darwin.go similarity index 86% rename from iface/name_darwin.go rename to client/iface/configurer/name_darwin.go index a4016ce153b..034ce388d5c 100644 --- a/iface/name_darwin.go +++ b/client/iface/configurer/name_darwin.go @@ -1,6 +1,6 @@ //go:build darwin -package iface +package configurer // WgInterfaceDefault is a default interface name of Wiretrustee const WgInterfaceDefault = "utun100" diff --git a/iface/uapi.go b/client/iface/configurer/uapi.go similarity index 96% rename from iface/uapi.go rename to client/iface/configurer/uapi.go index d7ff52e7b35..4801841ded6 100644 --- a/iface/uapi.go +++ b/client/iface/configurer/uapi.go @@ -1,6 +1,6 @@ //go:build !windows -package iface +package configurer import ( "net" diff --git a/iface/uapi_windows.go b/client/iface/configurer/uapi_windows.go similarity index 88% rename from iface/uapi_windows.go rename to client/iface/configurer/uapi_windows.go index e1f4663642f..46fa90c2ebc 100644 --- a/iface/uapi_windows.go +++ b/client/iface/configurer/uapi_windows.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "net" diff --git a/iface/wg_configurer_usp.go b/client/iface/configurer/usp.go similarity index 93% rename from iface/wg_configurer_usp.go rename to client/iface/configurer/usp.go index cd1d9d0b6c9..21d65ab2a5d 100644 --- a/iface/wg_configurer_usp.go +++ b/client/iface/configurer/usp.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "encoding/hex" @@ -19,15 +19,15 @@ import ( var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") -type wgUSPConfigurer struct { +type WGUSPConfigurer struct { device *device.Device deviceName string uapiListener net.Listener } -func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer { - wgCfg := &wgUSPConfigurer{ +func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer { + wgCfg := &WGUSPConfigurer{ device: device, deviceName: deviceName, } @@ -35,7 +35,7 @@ func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer { return wgCfg } -func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error { +func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error { log.Debugf("adding Wireguard private key") key, err := wgtypes.ParseKey(privateKey) if err != nil { @@ -52,7 +52,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { // parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { @@ -80,7 +80,7 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) removePeer(peerKey string) error { +func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ 
-97,7 +97,7 @@ func (c *wgUSPConfigurer) removePeer(peerKey string) error { return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error { +func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return err @@ -121,7 +121,7 @@ func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error { return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error { +func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error { ipc, err := c.device.IpcGet() if err != nil { return err @@ -185,7 +185,7 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error { } // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool -func (t *wgUSPConfigurer) startUAPI() { +func (t *WGUSPConfigurer) startUAPI() { var err error t.uapiListener, err = openUAPI(t.deviceName) if err != nil { @@ -207,7 +207,7 @@ func (t *wgUSPConfigurer) startUAPI() { }(t.uapiListener) } -func (t *wgUSPConfigurer) close() { +func (t *WGUSPConfigurer) Close() { if t.uapiListener != nil { err := t.uapiListener.Close() if err != nil { @@ -223,7 +223,7 @@ func (t *wgUSPConfigurer) close() { } } -func (t *wgUSPConfigurer) getStats(peerKey string) (WGStats, error) { +func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) { ipc, err := t.device.IpcGet() if err != nil { return WGStats{}, fmt.Errorf("ipc get: %w", err) diff --git a/iface/wg_configurer_usp_test.go b/client/iface/configurer/usp_test.go similarity index 99% rename from iface/wg_configurer_usp_test.go rename to client/iface/configurer/usp_test.go index ac0fc613090..775339f24a1 100644 --- a/iface/wg_configurer_usp_test.go +++ b/client/iface/configurer/usp_test.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "encoding/hex" diff --git a/client/iface/configurer/wgstats.go b/client/iface/configurer/wgstats.go new file mode 100644 index 00000000000..56d0d73109f --- /dev/null +++ b/client/iface/configurer/wgstats.go @@ -0,0 +1,9 @@ +package configurer + +import "time" + +type WGStats struct { + LastHandshake time.Time + TxBytes int64 + RxBytes int64 +} diff --git a/client/iface/device.go b/client/iface/device.go new file mode 100644 index 00000000000..0d4e6914554 --- /dev/null +++ b/client/iface/device.go @@ -0,0 +1,18 @@ +//go:build !android + +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" +) + +type WGTunDevice interface { + Create() (device.WGConfigurer, error) + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(address WGAddress) error + WgAddress() WGAddress + DeviceName() string + Close() error + FilteredDevice() *device.FilteredDevice +} diff --git a/iface/tun_adapter.go b/client/iface/device/adapter.go similarity index 94% rename from iface/tun_adapter.go rename to client/iface/device/adapter.go index adec93ed198..6ebc0539007 100644 --- a/iface/tun_adapter.go +++ b/client/iface/device/adapter.go @@ -1,4 +1,4 @@ -package iface +package device // TunAdapter is an interface for create tun device from external service type TunAdapter interface { diff --git a/iface/address.go b/client/iface/device/address.go similarity index 69% rename from iface/address.go rename to client/iface/device/address.go index 5ff4fbc0645..15de301daee 100644 --- a/iface/address.go +++ 
b/client/iface/device/address.go @@ -1,18 +1,18 @@ -package iface +package device import ( "fmt" "net" ) -// WGAddress Wireguard parsed address +// WGAddress WireGuard parsed address type WGAddress struct { IP net.IP Network *net.IPNet } -// parseWGAddress parse a string ("1.2.3.4/24") address to WG Address -func parseWGAddress(address string) (WGAddress, error) { +// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address +func ParseWGAddress(address string) (WGAddress, error) { ip, network, err := net.ParseCIDR(address) if err != nil { return WGAddress{}, err diff --git a/iface/tun_args.go b/client/iface/device/args.go similarity index 88% rename from iface/tun_args.go rename to client/iface/device/args.go index 0eac2c4c0ed..d7b86b3351a 100644 --- a/iface/tun_args.go +++ b/client/iface/device/args.go @@ -1,4 +1,4 @@ -package iface +package device type MobileIFaceArguments struct { TunAdapter TunAdapter // only for Android diff --git a/iface/tun_android.go b/client/iface/device/device_android.go similarity index 61% rename from iface/tun_android.go rename to client/iface/device/device_android.go index 50499309413..29e3f409df6 100644 --- a/iface/tun_android.go +++ b/client/iface/device/device_android.go @@ -1,7 +1,6 @@ //go:build android -// +build android -package iface +package device import ( "strings" @@ -12,11 +11,12 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -// ignore the wgTunDevice interface on Android because the creation of the tun device is different on this platform -type wgTunDevice struct { +// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform +type WGTunDevice struct { address WGAddress port int key string @@ -24,15 +24,15 @@ type wgTunDevice struct { iceBind *bind.ICEBind tunAdapter TunAdapter - name string - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + name string + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice { - return wgTunDevice{ +func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice { + return &WGTunDevice{ address: address, port: port, key: key, @@ -42,7 +42,7 @@ func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet } } -func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string) (wgConfigurer, error) { +func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { log.Info("create tun interface") routesString := routesToString(routes) @@ -61,24 +61,24 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string return nil, err } t.name = name - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) log.Debugf("attaching to interface %v", name) - t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) + t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) 
// without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, err } return t.configurer, nil } -func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -93,14 +93,14 @@ func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *wgTunDevice) UpdateAddr(addr WGAddress) error { +func (t *WGTunDevice) UpdateAddr(addr WGAddress) error { // todo implement return nil } -func (t *wgTunDevice) Close() error { +func (t *WGTunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -115,20 +115,20 @@ func (t *wgTunDevice) Close() error { return nil } -func (t *wgTunDevice) Device() *device.Device { +func (t *WGTunDevice) Device() *device.Device { return t.device } -func (t *wgTunDevice) DeviceName() string { +func (t *WGTunDevice) DeviceName() string { return t.name } -func (t *wgTunDevice) WgAddress() WGAddress { +func (t *WGTunDevice) WgAddress() WGAddress { return t.address } -func (t *wgTunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *WGTunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } func routesToString(routes []string) string { diff --git a/iface/tun_darwin.go b/client/iface/device/device_darwin.go similarity index 69% rename from iface/tun_darwin.go rename to client/iface/device/device_darwin.go index fcf9f8ba092..03e85a7f17f 100644 --- a/iface/tun_darwin.go +++ b/client/iface/device/device_darwin.go @@ -1,6 +1,6 @@ //go:build !ios -package iface +package device import ( "fmt" @@ -11,10 +11,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunDevice struct { +type TunDevice struct { name string address WGAddress port int @@ -22,14 +23,14 @@ type tunDevice struct { mtu int iceBind *bind.ICEBind - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { + return &TunDevice{ name: name, address: address, port: port, @@ -39,16 +40,16 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int, } } -func (t *tunDevice) Create() (wgConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { tunDevice, err := tun.CreateTUN(t.name, t.mtu) if err != nil { return nil, fmt.Errorf("error creating tun device: %s", err) } - 
t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -59,17 +60,17 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -84,14 +85,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) UpdateAddr(address WGAddress) error { +func (t *TunDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -105,20 +106,20 @@ func (t *tunDevice) Close() error { return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { return t.address } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } // assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (t *tunDevice) assignAddr() error { +func (t *TunDevice) assignAddr() error { cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) if out, err := cmd.CombinedOutput(); err != nil { log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out) diff --git a/iface/device_wrapper.go b/client/iface/device/device_filter.go similarity index 81% rename from iface/device_wrapper.go rename to client/iface/device/device_filter.go index 2fa21939573..f87f104293c 100644 --- a/iface/device_wrapper.go +++ b/client/iface/device/device_filter.go @@ -1,4 +1,4 @@ -package iface +package device import ( "net" @@ -28,22 +28,23 @@ type PacketFilter interface { SetNetwork(*net.IPNet) } -// DeviceWrapper to override Read or Write of packets -type DeviceWrapper struct { +// FilteredDevice to override Read or Write of packets +type FilteredDevice struct { tun.Device + filter PacketFilter mutex sync.RWMutex } -// newDeviceWrapper constructor function -func newDeviceWrapper(device tun.Device) *DeviceWrapper { - return &DeviceWrapper{ +// newDeviceFilter constructor function +func newDeviceFilter(device tun.Device) *FilteredDevice { + return &FilteredDevice{ Device: device, } } // Read wraps read method with filtering feature -func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { +func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { if n, err = d.Device.Read(bufs, sizes, offset); err != nil { return 0, err } @@ -68,7 +69,7 @@ func (d 
*DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err } // Write wraps write method with filtering feature -func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { +func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { d.mutex.RLock() filter := d.filter d.mutex.RUnlock() @@ -92,7 +93,7 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { } // SetFilter sets packet filter to device -func (d *DeviceWrapper) SetFilter(filter PacketFilter) { +func (d *FilteredDevice) SetFilter(filter PacketFilter) { d.mutex.Lock() d.filter = filter d.mutex.Unlock() diff --git a/iface/device_wrapper_test.go b/client/iface/device/device_filter_test.go similarity index 95% rename from iface/device_wrapper_test.go rename to client/iface/device/device_filter_test.go index 2d3725ea4bd..d3278b91805 100644 --- a/iface/device_wrapper_test.go +++ b/client/iface/device/device_filter_test.go @@ -1,4 +1,4 @@ -package iface +package device import ( "net" @@ -7,7 +7,8 @@ import ( "github.com/golang/mock/gomock" "github.com/google/gopacket" "github.com/google/gopacket/layers" - mocks "github.com/netbirdio/netbird/iface/mocks" + + mocks "github.com/netbirdio/netbird/client/iface/mocks" ) func TestDeviceWrapperRead(t *testing.T) { @@ -51,7 +52,7 @@ func TestDeviceWrapperRead(t *testing.T) { return 1, nil }) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) bufs := [][]byte{{}} sizes := []int{0} @@ -99,7 +100,7 @@ func TestDeviceWrapperRead(t *testing.T) { tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Write(mockBufs, 0).Return(1, nil) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) bufs := [][]byte{buffer.Bytes()} @@ -147,7 +148,7 @@ func TestDeviceWrapperRead(t *testing.T) { filter := mocks.NewMockPacketFilter(ctrl) filter.EXPECT().DropIncoming(gomock.Any()).Return(true) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) wrapped.filter = filter bufs := [][]byte{buffer.Bytes()} @@ -202,7 +203,7 @@ func TestDeviceWrapperRead(t *testing.T) { filter := mocks.NewMockPacketFilter(ctrl) filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) wrapped.filter = filter bufs := [][]byte{{}} diff --git a/iface/tun_ios.go b/client/iface/device/device_ios.go similarity index 63% rename from iface/tun_ios.go rename to client/iface/device/device_ios.go index 6d53cc33366..226e8a2e0cb 100644 --- a/iface/tun_ios.go +++ b/client/iface/device/device_ios.go @@ -1,7 +1,7 @@ //go:build ios // +build ios -package iface +package device import ( "os" @@ -12,10 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunDevice struct { +type TunDevice struct { name string address WGAddress port int @@ -23,14 +24,14 @@ type tunDevice struct { iceBind *bind.ICEBind tunFd int - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, 
tunFd int, filterFn bind.FilterFn) *TunDevice { + return &TunDevice{ name: name, address: address, port: port, @@ -40,7 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, transpor } } -func (t *tunDevice) Create() (wgConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { log.Infof("create tun interface") dupTunFd, err := unix.Dup(t.tunFd) @@ -62,24 +63,24 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, err } - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) log.Debug("Attaching to interface") - t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) + t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) // without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, err } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -94,17 +95,17 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) Device() *device.Device { +func (t *TunDevice) Device() *device.Device { return t.device } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -119,15 +120,15 @@ func (t *tunDevice) Close() error { return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { return t.address } -func (t *tunDevice) UpdateAddr(addr WGAddress) error { +func (t *TunDevice) UpdateAddr(addr WGAddress) error { // todo implement return nil } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } diff --git a/iface/tun_kernel_unix.go b/client/iface/device/device_kernel_unix.go similarity index 75% rename from iface/tun_kernel_unix.go rename to client/iface/device/device_kernel_unix.go index 220c078882b..f355d2cf76a 100644 --- a/iface/tun_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package device import ( "context" @@ -10,11 +10,12 @@ import ( "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/sharedsock" ) -type tunKernelDevice struct { +type TunKernelDevice struct { name string address WGAddress wgPort int @@ -31,11 +32,11 @@ type tunKernelDevice struct { filterFn bind.FilterFn } -func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet 
transport.Net) wgTunDevice { +func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { checkUser() ctx, cancel := context.WithCancel(context.Background()) - return &tunKernelDevice{ + return &TunKernelDevice{ ctx: ctx, ctxCancel: cancel, name: name, @@ -47,7 +48,7 @@ func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu in } } -func (t *tunKernelDevice) Create() (wgConfigurer, error) { +func (t *TunKernelDevice) Create() (WGConfigurer, error) { link := newWGLink(t.name) if err := link.recreate(); err != nil { @@ -67,16 +68,16 @@ func (t *tunKernelDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("set mtu: %w", err) } - configurer := newWGConfigurer(t.name) + configurer := configurer.NewKernelConfigurer(t.name) - if err := configurer.configureInterface(t.key, t.wgPort); err != nil { + if err := configurer.ConfigureInterface(t.key, t.wgPort); err != nil { return nil, fmt.Errorf("error configuring interface: %s", err) } return configurer, nil } -func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.udpMux != nil { return t.udpMux, nil } @@ -111,12 +112,12 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return t.udpMux, nil } -func (t *tunKernelDevice) UpdateAddr(address WGAddress) error { +func (t *TunKernelDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunKernelDevice) Close() error { +func (t *TunKernelDevice) Close() error { if t.link == nil { return nil } @@ -144,19 +145,19 @@ func (t *tunKernelDevice) Close() error { return closErr } -func (t *tunKernelDevice) WgAddress() WGAddress { +func (t *TunKernelDevice) WgAddress() WGAddress { return t.address } -func (t *tunKernelDevice) DeviceName() string { +func (t *TunKernelDevice) DeviceName() string { return t.name } -func (t *tunKernelDevice) Wrapper() *DeviceWrapper { +func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { return nil } // assignAddr Adds IP address to the tunnel interface -func (t *tunKernelDevice) assignAddr() error { +func (t *TunKernelDevice) assignAddr() error { return t.link.assignAddr(t.address) } diff --git a/iface/tun_netstack.go b/client/iface/device/device_netstack.go similarity index 56% rename from iface/tun_netstack.go rename to client/iface/device/device_netstack.go index de1ff6654dc..440a1ca191e 100644 --- a/iface/tun_netstack.go +++ b/client/iface/device/device_netstack.go @@ -1,7 +1,7 @@ //go:build !android // +build !android -package iface +package device import ( "fmt" @@ -10,11 +10,12 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/netstack" ) -type tunNetstackDevice struct { +type TunNetstackDevice struct { name string address WGAddress port int @@ -23,15 +24,15 @@ type tunNetstackDevice struct { listenAddress string iceBind *bind.ICEBind - device *device.Device - wrapper *DeviceWrapper - nsTun *netstack.NetStackTun - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + nsTun *netstack.NetStackTun + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func 
newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice { - return &tunNetstackDevice{ +func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice { + return &TunNetstackDevice{ name: name, address: address, port: wgPort, @@ -42,23 +43,23 @@ func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string } } -func (t *tunNetstackDevice) Create() (wgConfigurer, error) { +func (t *TunNetstackDevice) Create() (WGConfigurer, error) { log.Info("create netstack tun interface") t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu) tunIface, err := t.nsTun.Create() if err != nil { return nil, fmt.Errorf("error creating tun device: %s", err) } - t.wrapper = newDeviceWrapper(tunIface) + t.filteredDevice = newDeviceFilter(tunIface) t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { _ = tunIface.Close() return nil, fmt.Errorf("error configuring interface: %s", err) @@ -68,7 +69,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) { return t.configurer, nil } -func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -87,13 +88,13 @@ func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunNetstackDevice) UpdateAddr(WGAddress) error { +func (t *TunNetstackDevice) UpdateAddr(WGAddress) error { return nil } -func (t *tunNetstackDevice) Close() error { +func (t *TunNetstackDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -106,14 +107,14 @@ func (t *tunNetstackDevice) Close() error { return nil } -func (t *tunNetstackDevice) WgAddress() WGAddress { +func (t *TunNetstackDevice) WgAddress() WGAddress { return t.address } -func (t *tunNetstackDevice) DeviceName() string { +func (t *TunNetstackDevice) DeviceName() string { return t.name } -func (t *tunNetstackDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } diff --git a/iface/tun_usp_unix.go b/client/iface/device/device_usp_unix.go similarity index 63% rename from iface/tun_usp_unix.go rename to client/iface/device/device_usp_unix.go index 1c1d3ac89bf..4175f65569e 100644 --- a/iface/tun_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package device import ( "fmt" @@ -12,10 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunUSPDevice struct { +type USPDevice struct { name string address WGAddress port int @@ -23,39 +24,38 @@ type tunUSPDevice struct { mtu int iceBind *bind.ICEBind - device *device.Device - 
wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { +func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice { log.Infof("using userspace bind mode") checkUser() - return &tunUSPDevice{ + return &USPDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn), - } + iceBind: bind.NewICEBind(transportNet, filterFn)} } -func (t *tunUSPDevice) Create() (wgConfigurer, error) { +func (t *USPDevice) Create() (WGConfigurer, error) { log.Info("create tun interface") tunIface, err := tun.CreateTUN(t.name, t.mtu) if err != nil { log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err) return nil, fmt.Errorf("error creating tun device: %s", err) } - t.wrapper = newDeviceWrapper(tunIface) + t.filteredDevice = newDeviceFilter(tunIface) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -66,17 +66,17 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -96,14 +96,14 @@ func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunUSPDevice) UpdateAddr(address WGAddress) error { +func (t *USPDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunUSPDevice) Close() error { +func (t *USPDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -116,20 +116,20 @@ func (t *tunUSPDevice) Close() error { return nil } -func (t *tunUSPDevice) WgAddress() WGAddress { +func (t *USPDevice) WgAddress() WGAddress { return t.address } -func (t *tunUSPDevice) DeviceName() string { +func (t *USPDevice) DeviceName() string { return t.name } -func (t *tunUSPDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *USPDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } // assignAddr Adds IP address to the tunnel interface -func (t *tunUSPDevice) assignAddr() error { +func (t *USPDevice) assignAddr() error { link := newWGLink(t.name) return link.assignAddr(t.address) diff --git a/iface/tun_windows.go b/client/iface/device/device_windows.go similarity index 75% rename from iface/tun_windows.go rename to client/iface/device/device_windows.go index afb67bcc022..f3e216ccd5d 100644 --- a/iface/tun_windows.go +++ b/client/iface/device/device_windows.go @@ -1,4 +1,4 @@ -package 
iface +package device import ( "fmt" @@ -11,12 +11,13 @@ import ( "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}" -type tunDevice struct { +type TunDevice struct { name string address WGAddress port int @@ -26,13 +27,13 @@ type tunDevice struct { device *device.Device nativeTunDevice *tun.NativeTun - wrapper *DeviceWrapper + filteredDevice *FilteredDevice udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { + return &TunDevice{ name: name, address: address, port: port, @@ -50,7 +51,7 @@ func getGUID() (windows.GUID, error) { return windows.GUIDFromString(guidString) } -func (t *tunDevice) Create() (wgConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { guid, err := getGUID() if err != nil { log.Errorf("failed to get GUID: %s", err) @@ -62,11 +63,11 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error creating tun device: %s", err) } t.nativeTunDevice = tunDevice.(*tun.NativeTun) - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -92,17 +93,17 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -117,14 +118,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) UpdateAddr(address WGAddress) error { +func (t *TunDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -138,19 +139,19 @@ func (t *tunDevice) Close() error { } return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { return t.address } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } -func (t *tunDevice) getInterfaceGUIDString() (string, error) { +func (t *TunDevice) GetInterfaceGUIDString() 
(string, error) { if t.nativeTunDevice == nil { return "", fmt.Errorf("interface has not been initialized yet") } @@ -164,7 +165,7 @@ func (t *tunDevice) getInterfaceGUIDString() (string, error) { } // assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (t *tunDevice) assignAddr() error { +func (t *TunDevice) assignAddr() error { luid := winipcfg.LUID(t.nativeTunDevice.LUID()) log.Debugf("adding address %s to interface: %s", t.address.IP, t.name) return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())}) diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go new file mode 100644 index 00000000000..0196b0085e1 --- /dev/null +++ b/client/iface/device/interface.go @@ -0,0 +1,20 @@ +package device + +import ( + "net" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/configurer" +) + +type WGConfigurer interface { + ConfigureInterface(privateKey string, port int) error + UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeer(peerKey string) error + AddAllowedIP(peerKey string, allowedIP string) error + RemoveAllowedIP(peerKey string, allowedIP string) error + Close() + GetStats(peerKey string) (configurer.WGStats, error) +} diff --git a/iface/module.go b/client/iface/device/kernel_module.go similarity index 92% rename from iface/module.go rename to client/iface/device/kernel_module.go index ca70cf3c7de..1bdd6f7c6d9 100644 --- a/iface/module.go +++ b/client/iface/device/kernel_module.go @@ -1,6 +1,6 @@ //go:build (!linux && !freebsd) || android -package iface +package device // WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only) func WireGuardModuleIsLoaded() bool { diff --git a/iface/module_freebsd.go b/client/iface/device/kernel_module_freebsd.go similarity index 84% rename from iface/module_freebsd.go rename to client/iface/device/kernel_module_freebsd.go index 00ad882c29a..dd6c8b40826 100644 --- a/iface/module_freebsd.go +++ b/client/iface/device/kernel_module_freebsd.go @@ -1,4 +1,4 @@ -package iface +package device // WireGuardModuleIsLoaded check if kernel support wireguard func WireGuardModuleIsLoaded() bool { @@ -10,8 +10,8 @@ func WireGuardModuleIsLoaded() bool { return false } -// tunModuleIsLoaded check if tun module exist, if is not attempt to load it -func tunModuleIsLoaded() bool { +// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it +func ModuleTunIsLoaded() bool { // Assume tun supported by freebsd kernel by default // TODO: implement check for module loaded in kernel or build-it return true diff --git a/iface/module_linux.go b/client/iface/device/kernel_module_linux.go similarity index 98% rename from iface/module_linux.go rename to client/iface/device/kernel_module_linux.go index 11c0482d58a..0d195779dfe 100644 --- a/iface/module_linux.go +++ b/client/iface/device/kernel_module_linux.go @@ -1,7 +1,7 @@ //go:build linux && !android // Package iface provides wireguard network interface creation and management -package iface +package device import ( "bufio" @@ -66,8 +66,8 @@ func getModuleRoot() string { return filepath.Join(moduleLibDir, string(uname.Release[:i])) } -// tunModuleIsLoaded check if tun module exist, if is not attempt to load it -func tunModuleIsLoaded() bool { +// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it +func ModuleTunIsLoaded() bool { _, err := 
os.Stat("/dev/net/tun") if err == nil { return true diff --git a/iface/module_linux_test.go b/client/iface/device/kernel_module_linux_test.go similarity index 98% rename from iface/module_linux_test.go rename to client/iface/device/kernel_module_linux_test.go index 97e9b1f78b5..de9656e470e 100644 --- a/iface/module_linux_test.go +++ b/client/iface/device/kernel_module_linux_test.go @@ -1,4 +1,6 @@ -package iface +//go:build linux && !android + +package device import ( "bufio" @@ -132,7 +134,7 @@ func resetGlobals() { } func createFiles(t *testing.T) (string, []module) { - t.Helper() + t.Helper() writeFile := func(path, text string) { if err := os.WriteFile(path, []byte(text), 0644); err != nil { t.Fatal(err) @@ -168,7 +170,7 @@ func createFiles(t *testing.T) (string, []module) { } func getRandomLoadedModule(t *testing.T) (string, error) { - t.Helper() + t.Helper() f, err := os.Open("/proc/modules") if err != nil { return "", err diff --git a/iface/tun_link_freebsd.go b/client/iface/device/wg_link_freebsd.go similarity index 95% rename from iface/tun_link_freebsd.go rename to client/iface/device/wg_link_freebsd.go index be7921fdb5e..104010f47af 100644 --- a/iface/tun_link_freebsd.go +++ b/client/iface/device/wg_link_freebsd.go @@ -1,10 +1,11 @@ -package iface +package device import ( "fmt" - "github.com/netbirdio/netbird/iface/freebsd" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/freebsd" ) type wgLink struct { diff --git a/iface/tun_link_linux.go b/client/iface/device/wg_link_linux.go similarity index 99% rename from iface/tun_link_linux.go rename to client/iface/device/wg_link_linux.go index 3ce644e8452..a15cffe4852 100644 --- a/iface/tun_link_linux.go +++ b/client/iface/device/wg_link_linux.go @@ -1,6 +1,6 @@ //go:build linux && !android -package iface +package device import ( "fmt" diff --git a/iface/wg_log.go b/client/iface/device/wg_log.go similarity index 93% rename from iface/wg_log.go rename to client/iface/device/wg_log.go index b44f6fc0b28..db2f3111f0f 100644 --- a/iface/wg_log.go +++ b/client/iface/device/wg_log.go @@ -1,4 +1,4 @@ -package iface +package device import ( "os" diff --git a/client/iface/device/windows_guid.go b/client/iface/device/windows_guid.go new file mode 100644 index 00000000000..1c7d40d1313 --- /dev/null +++ b/client/iface/device/windows_guid.go @@ -0,0 +1,4 @@ +package device + +// CustomWindowsGUIDString is a custom GUID string for the interface +var CustomWindowsGUIDString string diff --git a/client/iface/device_android.go b/client/iface/device_android.go new file mode 100644 index 00000000000..3d15080fff4 --- /dev/null +++ b/client/iface/device_android.go @@ -0,0 +1,16 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" +) + +type WGTunDevice interface { + Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(address WGAddress) error + WgAddress() WGAddress + DeviceName() string + Close() error + FilteredDevice() *device.FilteredDevice +} diff --git a/iface/freebsd/errors.go b/client/iface/freebsd/errors.go similarity index 100% rename from iface/freebsd/errors.go rename to client/iface/freebsd/errors.go diff --git a/iface/freebsd/iface.go b/client/iface/freebsd/iface.go similarity index 100% rename from iface/freebsd/iface.go rename to client/iface/freebsd/iface.go diff --git a/iface/freebsd/iface_internal_test.go 
b/client/iface/freebsd/iface_internal_test.go similarity index 100% rename from iface/freebsd/iface_internal_test.go rename to client/iface/freebsd/iface_internal_test.go diff --git a/iface/freebsd/link.go b/client/iface/freebsd/link.go similarity index 100% rename from iface/freebsd/link.go rename to client/iface/freebsd/link.go diff --git a/iface/iface.go b/client/iface/iface.go similarity index 79% rename from iface/iface.go rename to client/iface/iface.go index 545feffcfc2..accf5ce0afb 100644 --- a/iface/iface.go +++ b/client/iface/iface.go @@ -9,28 +9,27 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) const ( - DefaultMTU = 1280 - DefaultWgPort = 51820 + DefaultMTU = 1280 + DefaultWgPort = 51820 + WgInterfaceDefault = configurer.WgInterfaceDefault ) -// WGIface represents a interface instance +type WGAddress = device.WGAddress + +// WGIface represents an interface instance type WGIface struct { - tun wgTunDevice + tun WGTunDevice userspaceBind bool mu sync.Mutex - configurer wgConfigurer - filter PacketFilter -} - -type WGStats struct { - LastHandshake time.Time - TxBytes int64 - RxBytes int64 + configurer device.WGConfigurer + filter device.PacketFilter } // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind @@ -44,7 +43,7 @@ func (w *WGIface) Name() string { } // Address returns the interface address -func (w *WGIface) Address() WGAddress { +func (w *WGIface) Address() device.WGAddress { return w.tun.WgAddress() } @@ -75,7 +74,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error { w.mu.Lock() defer w.mu.Unlock() - addr, err := parseWGAddress(newAddr) + addr, err := device.ParseWGAddress(newAddr) if err != nil { return err } @@ -90,7 +89,7 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D defer w.mu.Unlock() log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) - return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) + return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) } // RemovePeer removes a Wireguard Peer from the interface iface @@ -99,7 +98,7 @@ func (w *WGIface) RemovePeer(peerKey string) error { defer w.mu.Unlock() log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName()) - return w.configurer.removePeer(peerKey) + return w.configurer.RemovePeer(peerKey) } // AddAllowedIP adds a prefix to the allowed IPs list of peer @@ -108,7 +107,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { defer w.mu.Unlock() log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) - return w.configurer.addAllowedIP(peerKey, allowedIP) + return w.configurer.AddAllowedIP(peerKey, allowedIP) } // RemoveAllowedIP removes a prefix from the allowed IPs list of peer @@ -117,7 +116,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { defer w.mu.Unlock() log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) - return w.configurer.removeAllowedIP(peerKey, allowedIP) + return w.configurer.RemoveAllowedIP(peerKey, allowedIP) } // Close closes the tunnel interface @@ -144,23 +143,23 @@ func (w 
*WGIface) Close() error { } // SetFilter sets packet filters for the userspace implementation -func (w *WGIface) SetFilter(filter PacketFilter) error { +func (w *WGIface) SetFilter(filter device.PacketFilter) error { w.mu.Lock() defer w.mu.Unlock() - if w.tun.Wrapper() == nil { + if w.tun.FilteredDevice() == nil { return fmt.Errorf("userspace packet filtering not handled on this device") } w.filter = filter w.filter.SetNetwork(w.tun.WgAddress().Network) - w.tun.Wrapper().SetFilter(filter) + w.tun.FilteredDevice().SetFilter(filter) return nil } // GetFilter returns packet filter used by interface if it uses userspace device implementation -func (w *WGIface) GetFilter() PacketFilter { +func (w *WGIface) GetFilter() device.PacketFilter { w.mu.Lock() defer w.mu.Unlock() @@ -168,16 +167,16 @@ func (w *WGIface) GetFilter() PacketFilter { } // GetDevice to interact with raw device (with filtering) -func (w *WGIface) GetDevice() *DeviceWrapper { +func (w *WGIface) GetDevice() *device.FilteredDevice { w.mu.Lock() defer w.mu.Unlock() - return w.tun.Wrapper() + return w.tun.FilteredDevice() } // GetStats returns the last handshake time, rx and tx bytes for the given peer -func (w *WGIface) GetStats(peerKey string) (WGStats, error) { - return w.configurer.getStats(peerKey) +func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { + return w.configurer.GetStats(peerKey) } func (w *WGIface) waitUntilRemoved() error { diff --git a/iface/iface_android.go b/client/iface/iface_android.go similarity index 67% rename from iface/iface_android.go rename to client/iface/iface_android.go index 99f6885a5e9..5ed476e7060 100644 --- a/iface/iface_android.go +++ b/client/iface/iface_android.go @@ -5,18 +5,19 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), + tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/iface/iface_create.go b/client/iface/iface_create.go similarity index 100% rename from iface/iface_create.go rename to client/iface/iface_create.go diff --git a/iface/iface_darwin.go b/client/iface/iface_darwin.go similarity index 68% rename from iface/iface_darwin.go rename to client/iface/iface_darwin.go index f48f324c362..b46ea0f8067 100644 --- a/iface/iface_darwin.go +++ b/client/iface/iface_darwin.go @@ -9,13 +9,14 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" ) // NewWGIFace Creates a new 
WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } @@ -25,11 +26,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, } if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) return wgIFace, nil } - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) + wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) return wgIFace, nil } diff --git a/iface/iface_destroy_bsd.go b/client/iface/iface_destroy_bsd.go similarity index 100% rename from iface/iface_destroy_bsd.go rename to client/iface/iface_destroy_bsd.go diff --git a/iface/iface_destroy_linux.go b/client/iface/iface_destroy_linux.go similarity index 100% rename from iface/iface_destroy_linux.go rename to client/iface/iface_destroy_linux.go diff --git a/iface/iface_destroy_mobile.go b/client/iface/iface_destroy_mobile.go similarity index 100% rename from iface/iface_destroy_mobile.go rename to client/iface/iface_destroy_mobile.go diff --git a/iface/iface_destroy_windows.go b/client/iface/iface_destroy_windows.go similarity index 100% rename from iface/iface_destroy_windows.go rename to client/iface/iface_destroy_windows.go diff --git a/iface/iface_ios.go b/client/iface/iface_ios.go similarity index 59% rename from iface/iface_ios.go rename to client/iface/iface_ios.go index 6babe596419..fc0214748c1 100644 --- a/iface/iface_ios.go +++ b/client/iface/iface_ios.go @@ -7,17 +7,18 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), + tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/iface/iface_moc.go b/client/iface/iface_moc.go similarity index 76% rename from iface/iface_moc.go rename to client/iface/iface_moc.go index fab3054a092..703da9ce004 100644 --- a/iface/iface_moc.go +++ b/client/iface/iface_moc.go @@ -6,7 +6,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - 
"github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) type MockWGIface struct { @@ -14,7 +16,7 @@ type MockWGIface struct { CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error IsUserspaceBindFunc func() bool NameFunc func() string - AddressFunc func() WGAddress + AddressFunc func() device.WGAddress ToInterfaceFunc func() *net.Interface UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpdateAddrFunc func(newAddr string) error @@ -23,10 +25,10 @@ type MockWGIface struct { AddAllowedIPFunc func(peerKey string, allowedIP string) error RemoveAllowedIPFunc func(peerKey string, allowedIP string) error CloseFunc func() error - SetFilterFunc func(filter PacketFilter) error - GetFilterFunc func() PacketFilter - GetDeviceFunc func() *DeviceWrapper - GetStatsFunc func(peerKey string) (WGStats, error) + SetFilterFunc func(filter device.PacketFilter) error + GetFilterFunc func() device.PacketFilter + GetDeviceFunc func() *device.FilteredDevice + GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDStringFunc func() (string, error) } @@ -50,7 +52,7 @@ func (m *MockWGIface) Name() string { return m.NameFunc() } -func (m *MockWGIface) Address() WGAddress { +func (m *MockWGIface) Address() device.WGAddress { return m.AddressFunc() } @@ -86,18 +88,18 @@ func (m *MockWGIface) Close() error { return m.CloseFunc() } -func (m *MockWGIface) SetFilter(filter PacketFilter) error { +func (m *MockWGIface) SetFilter(filter device.PacketFilter) error { return m.SetFilterFunc(filter) } -func (m *MockWGIface) GetFilter() PacketFilter { +func (m *MockWGIface) GetFilter() device.PacketFilter { return m.GetFilterFunc() } -func (m *MockWGIface) GetDevice() *DeviceWrapper { +func (m *MockWGIface) GetDevice() *device.FilteredDevice { return m.GetDeviceFunc() } -func (m *MockWGIface) GetStats(peerKey string) (WGStats, error) { +func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { return m.GetStatsFunc(peerKey) } diff --git a/iface/iface_test.go b/client/iface/iface_test.go similarity index 98% rename from iface/iface_test.go rename to client/iface/iface_test.go index 8de9f647e94..87a68addbfc 100644 --- a/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/device" ) // keep darwin compatibility @@ -414,7 +416,7 @@ func Test_ConnectPeers(t *testing.T) { } guid := fmt.Sprintf("{%s}", uuid.New().String()) - CustomWindowsGUIDString = strings.ToLower(guid) + device.CustomWindowsGUIDString = strings.ToLower(guid) iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil) if err != nil { @@ -436,7 +438,7 @@ func Test_ConnectPeers(t *testing.T) { } guid = fmt.Sprintf("{%s}", uuid.New().String()) - CustomWindowsGUIDString = strings.ToLower(guid) + device.CustomWindowsGUIDString = strings.ToLower(guid) newNet, err = stdnet.NewNet() if err != nil { diff --git a/iface/iface_unix.go b/client/iface/iface_unix.go similarity index 53% rename from iface/iface_unix.go rename to client/iface/iface_unix.go index 9608df1ad9a..09dbb2c1f7d 100644 --- a/iface/iface_unix.go +++ b/client/iface/iface_unix.go @@ -8,13 +8,14 @@ import ( "github.com/pion/transport/v3" - 
"github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } @@ -23,21 +24,21 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, // move the kernel/usp/netstack preference evaluation to upper layer if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) wgIFace.userspaceBind = true return wgIFace, nil } - if WireGuardModuleIsLoaded() { - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + if device.WireGuardModuleIsLoaded() { + wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) wgIFace.userspaceBind = false return wgIFace, nil } - if !tunModuleIsLoaded() { + if !device.ModuleTunIsLoaded() { return nil, fmt.Errorf("couldn't check or load tun module") } - wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) + wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) wgIFace.userspaceBind = true return wgIFace, nil } diff --git a/iface/iface_windows.go b/client/iface/iface_windows.go similarity index 52% rename from iface/iface_windows.go rename to client/iface/iface_windows.go index c5edd27a9ce..6845ef3ddd6 100644 --- a/iface/iface_windows.go +++ b/client/iface/iface_windows.go @@ -5,13 +5,14 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } @@ -21,11 +22,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, } if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), 
filterFn) return wgIFace, nil } - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) + wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) return wgIFace, nil } @@ -36,5 +37,5 @@ func (w *WGIface) CreateOnAndroid([]string, string, []string) error { // GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only func (w *WGIface) GetInterfaceGUIDString() (string, error) { - return w.tun.(*tunDevice).getInterfaceGUIDString() + return w.tun.(*device.TunDevice).GetInterfaceGUIDString() } diff --git a/iface/iwginterface.go b/client/iface/iwginterface.go similarity index 65% rename from iface/iwginterface.go rename to client/iface/iwginterface.go index 501f51d2b1f..cb6d7ccd9ad 100644 --- a/iface/iwginterface.go +++ b/client/iface/iwginterface.go @@ -8,7 +8,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) type IWGIface interface { @@ -16,7 +18,7 @@ type IWGIface interface { CreateOnAndroid(routeRange []string, ip string, domains []string) error IsUserspaceBind() bool Name() string - Address() WGAddress + Address() device.WGAddress ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error @@ -25,8 +27,8 @@ type IWGIface interface { AddAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error Close() error - SetFilter(filter PacketFilter) error - GetFilter() PacketFilter - GetDevice() *DeviceWrapper - GetStats(peerKey string) (WGStats, error) + SetFilter(filter device.PacketFilter) error + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) } diff --git a/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go similarity index 65% rename from iface/iwginterface_windows.go rename to client/iface/iwginterface_windows.go index b5053474eca..6baeb66ae0e 100644 --- a/iface/iwginterface_windows.go +++ b/client/iface/iwginterface_windows.go @@ -6,7 +6,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) type IWGIface interface { @@ -14,7 +16,7 @@ type IWGIface interface { CreateOnAndroid(routeRange []string, ip string, domains []string) error IsUserspaceBind() bool Name() string - Address() WGAddress + Address() device.WGAddress ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error @@ -23,9 +25,9 @@ type IWGIface interface { AddAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error Close() error - SetFilter(filter PacketFilter) error - GetFilter() PacketFilter - GetDevice() *DeviceWrapper - GetStats(peerKey string) (WGStats, error) + SetFilter(filter device.PacketFilter) error + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDString() (string, error) } diff --git a/iface/mocks/README.md b/client/iface/mocks/README.md similarity index 100% rename from iface/mocks/README.md rename to 
client/iface/mocks/README.md diff --git a/iface/mocks/filter.go b/client/iface/mocks/filter.go similarity index 97% rename from iface/mocks/filter.go rename to client/iface/mocks/filter.go index 2d80d69f163..6348e0e771d 100644 --- a/iface/mocks/filter.go +++ b/client/iface/mocks/filter.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter) +// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter) // Package mocks is a generated GoMock package. package mocks diff --git a/iface/mocks/iface/mocks/filter.go b/client/iface/mocks/iface/mocks/filter.go similarity index 97% rename from iface/mocks/iface/mocks/filter.go rename to client/iface/mocks/iface/mocks/filter.go index 059a2b9a01d..17e123abb94 100644 --- a/iface/mocks/iface/mocks/filter.go +++ b/client/iface/mocks/iface/mocks/filter.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter) +// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter) // Package mocks is a generated GoMock package. package mocks diff --git a/iface/mocks/tun.go b/client/iface/mocks/tun.go similarity index 100% rename from iface/mocks/tun.go rename to client/iface/mocks/tun.go diff --git a/iface/netstack/dialer.go b/client/iface/netstack/dialer.go similarity index 100% rename from iface/netstack/dialer.go rename to client/iface/netstack/dialer.go diff --git a/iface/netstack/env.go b/client/iface/netstack/env.go similarity index 100% rename from iface/netstack/env.go rename to client/iface/netstack/env.go diff --git a/iface/netstack/proxy.go b/client/iface/netstack/proxy.go similarity index 100% rename from iface/netstack/proxy.go rename to client/iface/netstack/proxy.go diff --git a/iface/netstack/tun.go b/client/iface/netstack/tun.go similarity index 100% rename from iface/netstack/tun.go rename to client/iface/netstack/tun.go diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index eec3d3b8cf1..7d999669abb 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -9,8 +9,8 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/acl/mocks" - "github.com/netbirdio/netbird/iface" mgmProto "github.com/netbirdio/netbird/management/proto" ) diff --git a/client/internal/acl/mocks/iface_mapper.go b/client/internal/acl/mocks/iface_mapper.go index 621b2951364..3ed12b6dd76 100644 --- a/client/internal/acl/mocks/iface_mapper.go +++ b/client/internal/acl/mocks/iface_mapper.go @@ -8,7 +8,8 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - iface "github.com/netbirdio/netbird/iface" + iface "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) // MockIFaceMapper is a mock of IFaceMapper interface. @@ -77,7 +78,7 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call { } // SetFilter mocks base method. 
-func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error { +func (m *MockIFaceMapper) SetFilter(arg0 device.PacketFilter) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetFilter", arg0) ret0, _ := ret[0].(error) diff --git a/client/internal/config.go b/client/internal/config.go index 1df1e0547ae..ee54c6380c5 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -16,9 +16,9 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/ssh" - "github.com/netbirdio/netbird/iface" mgm "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/util" ) diff --git a/client/internal/connect.go b/client/internal/connect.go index 36b340cfbe8..c77f95603d0 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -17,13 +17,14 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/iface" mgm "github.com/netbirdio/netbird/management/client" mgmProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/relay/auth/hmac" @@ -70,7 +71,7 @@ func (c *ConnectClient) RunWithProbes( // RunOnAndroid with main logic on mobile system func (c *ConnectClient) RunOnAndroid( - tunAdapter iface.TunAdapter, + tunAdapter device.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, dnsAddresses []string, @@ -205,7 +206,7 @@ func (c *ConnectClient) run( localPeerState := peer.LocalPeerState{ IP: loginResp.GetPeerConfig().GetAddress(), PubKey: myPrivateKey.PublicKey().String(), - KernelInterface: iface.WireGuardModuleIsLoaded(), + KernelInterface: device.WireGuardModuleIsLoaded(), FQDN: loginResp.GetPeerConfig().GetFqdn(), } c.statusRecorder.UpdateLocalPeerState(localPeerState) diff --git a/client/internal/dns/response_writer_test.go b/client/internal/dns/response_writer_test.go index 5a00477007f..85796440680 100644 --- a/client/internal/dns/response_writer_test.go +++ b/client/internal/dns/response_writer_test.go @@ -9,7 +9,7 @@ import ( "github.com/google/gopacket/layers" "github.com/miekg/dns" - "github.com/netbirdio/netbird/iface/mocks" + "github.com/netbirdio/netbird/client/iface/mocks" ) func TestResponseWriterLocalAddr(t *testing.T) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index b9552bc17c0..53d18a67814 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -15,16 +15,18 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" + pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" 
"github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/iface" - pfmock "github.com/netbirdio/netbird/iface/mocks" ) type mocWGIface struct { - filter iface.PacketFilter + filter device.PacketFilter } func (w *mocWGIface) Name() string { @@ -43,11 +45,11 @@ func (w *mocWGIface) ToInterface() *net.Interface { panic("implement me") } -func (w *mocWGIface) GetFilter() iface.PacketFilter { +func (w *mocWGIface) GetFilter() device.PacketFilter { return w.filter } -func (w *mocWGIface) GetDevice() *iface.DeviceWrapper { +func (w *mocWGIface) GetDevice() *device.FilteredDevice { panic("implement me") } @@ -59,13 +61,13 @@ func (w *mocWGIface) IsUserspaceBind() bool { return false } -func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error { +func (w *mocWGIface) SetFilter(filter device.PacketFilter) error { w.filter = filter return nil } -func (w *mocWGIface) GetStats(_ string) (iface.WGStats, error) { - return iface.WGStats{}, nil +func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) { + return configurer.WGStats{}, nil } var zoneRecords = []nbdns.SimpleRecord{ diff --git a/client/internal/dns/wgiface.go b/client/internal/dns/wgiface.go index 2f08e8d52b8..69bc8365998 100644 --- a/client/internal/dns/wgiface.go +++ b/client/internal/dns/wgiface.go @@ -5,7 +5,9 @@ package dns import ( "net" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) // WGIface defines subset methods of interface required for manager @@ -14,7 +16,7 @@ type WGIface interface { Address() iface.WGAddress ToInterface() *net.Interface IsUserspaceBind() bool - GetFilter() iface.PacketFilter - GetDevice() *iface.DeviceWrapper - GetStats(peerKey string) (iface.WGStats, error) + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) } diff --git a/client/internal/dns/wgiface_windows.go b/client/internal/dns/wgiface_windows.go index f8bb80fb934..765132fdbf2 100644 --- a/client/internal/dns/wgiface_windows.go +++ b/client/internal/dns/wgiface_windows.go @@ -1,14 +1,18 @@ package dns -import "github.com/netbirdio/netbird/iface" +import ( + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" +) // WGIface defines subset methods of interface required for manager type WGIface interface { Name() string Address() iface.WGAddress IsUserspaceBind() bool - GetFilter() iface.PacketFilter - GetDevice() *iface.DeviceWrapper - GetStats(peerKey string) (iface.WGStats, error) + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDString() (string, error) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 998cbce2de1..c51901a225d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -23,9 +23,12 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" 
"github.com/netbirdio/netbird/client/internal/relay" @@ -36,8 +39,6 @@ import ( nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" mgm "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" @@ -619,7 +620,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{ IP: e.config.WgAddr, PubKey: e.config.WgPrivateKey.PublicKey().String(), - KernelInterface: iface.WireGuardModuleIsLoaded(), + KernelInterface: device.WireGuardModuleIsLoaded(), FQDN: conf.GetFqdn(), }) @@ -1165,15 +1166,15 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { log.Errorf("failed to create pion's stdnet: %s", err) } - var mArgs *iface.MobileIFaceArguments + var mArgs *device.MobileIFaceArguments switch runtime.GOOS { case "android": - mArgs = &iface.MobileIFaceArguments{ + mArgs = &device.MobileIFaceArguments{ TunAdapter: e.mobileDep.TunAdapter, TunFd: int(e.mobileDep.FileDescriptor), } case "ios": - mArgs = &iface.MobileIFaceArguments{ + mArgs = &device.MobileIFaceArguments{ TunFd: int(e.mobileDep.FileDescriptor), } default: diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 95aadf14186..29a8439a2df 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -25,14 +25,15 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" mgmt "github.com/netbirdio/netbird/management/client" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" @@ -874,7 +875,7 @@ func TestEngine_MultiplePeers(t *testing.T) { mu.Lock() defer mu.Unlock() guid := fmt.Sprintf("{%s}", uuid.New().String()) - iface.CustomWindowsGUIDString = strings.ToLower(guid) + device.CustomWindowsGUIDString = strings.ToLower(guid) err = engine.Start() if err != nil { t.Errorf("unable to start engine for peer %d with error %v", j, err) diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 2355c67c3bd..2b0c92cc690 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -1,16 +1,16 @@ package internal import ( + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" ) // MobileDependency collect all dependencies for mobile platform type MobileDependency struct { // Android only - TunAdapter iface.TunAdapter + TunAdapter device.TunAdapter IFaceDiscover stdnet.ExternalIFaceDiscover NetworkChangeListener listener.NetworkChangeListener HostDNSAddresses []string diff --git a/client/internal/peer/conn.go 
b/client/internal/peer/conn.go index baff1372a16..ad84bd7006b 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -15,9 +15,10 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" - "github.com/netbirdio/netbird/iface" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -684,7 +685,7 @@ func (conn *Conn) setStatusToDisconnected() { // todo rethink status updates conn.log.Debugf("error while updating peer's state, err: %v", err) } - if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, iface.WGStats{}); err != nil { + if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil { conn.log.Debugf("failed to reset wireguard stats for peer: %s", err) } } diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 22e5409f894..b4926a9d2ef 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -9,9 +9,9 @@ import ( "github.com/magiconair/properties/assert" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util" ) diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 915fa63f005..a28992fac13 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -11,8 +11,8 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/relay" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/management/domain" relayClient "github.com/netbirdio/netbird/relay/client" ) @@ -203,7 +203,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) { state, ok := d.peers[peerPubKey] if !ok { - return State{}, iface.ErrPeerNotFound + return State{}, configurer.ErrPeerNotFound } return state, nil } @@ -412,7 +412,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { } // UpdateWireGuardPeerState updates the WireGuard bits of the peer state -func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error { +func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error { d.mux.Lock() defer d.mux.Unlock() diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 8bf1b75684a..c4e9d195074 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -15,9 +15,9 @@ import ( "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index db2caea7f81..eaa23215135 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -10,12 +10,12 
@@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/static" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index e86a5281086..ac94d4a5c74 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -13,10 +13,10 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index d97fe631fe4..d7ddf7ae8b7 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -14,6 +14,8 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" @@ -21,7 +23,6 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" - "github.com/netbirdio/netbird/iface" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -102,7 +103,7 @@ func NewManager( }, func(prefix netip.Prefix, peerKey string) error { if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil { - if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) { + if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) { return err } log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err) diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 2995e2740f8..2f26f7a5ec9 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 58a66715cd2..908279c885a 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -5,9 +5,9 @@ import ( "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" + 
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/routeselector" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/util/net" ) diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index 2057b9cc8ed..c75a0a7f22e 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -7,8 +7,8 @@ import ( "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) { diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 1d1a4b0633e..ef38d57078f 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -11,9 +11,9 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go index 13e1229f84f..bb620ee6893 100644 --- a/client/internal/routemanager/sysctl/sysctl_linux.go +++ b/client/internal/routemanager/sysctl/sysctl_linux.go @@ -13,7 +13,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) const ( diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 10944c1e22d..d1cb83bfbc3 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -5,9 +5,9 @@ import ( "net/netip" "sync" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - "github.com/netbirdio/netbird/iface" ) type Nexthop struct { diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 90f06ba7835..9258f4a4e3b 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -16,10 +16,10 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" - "github.com/netbirdio/netbird/iface" nbnet "github.com/netbirdio/netbird/util/net" ) diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 94965c119b9..238225807f8 100644 --- 
a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -19,7 +19,7 @@ import ( "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) type dialer interface { diff --git a/iface/tun.go b/iface/tun.go deleted file mode 100644 index 7d0a57ed6ab..00000000000 --- a/iface/tun.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build !android -// +build !android - -package iface - -import ( - "github.com/netbirdio/netbird/iface/bind" -) - -// CustomWindowsGUIDString is a custom GUID string for the interface -var CustomWindowsGUIDString string - -type wgTunDevice interface { - Create() (wgConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) - UpdateAddr(address WGAddress) error - WgAddress() WGAddress - DeviceName() string - Close() error - Wrapper() *DeviceWrapper // todo eliminate this function -} diff --git a/iface/wg_configurer.go b/iface/wg_configurer.go deleted file mode 100644 index dd38ba0757a..00000000000 --- a/iface/wg_configurer.go +++ /dev/null @@ -1,21 +0,0 @@ -package iface - -import ( - "errors" - "net" - "time" - - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -var ErrPeerNotFound = errors.New("peer not found") - -type wgConfigurer interface { - configureInterface(privateKey string, port int) error - updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error - removePeer(peerKey string) error - addAllowedIP(peerKey string, allowedIP string) error - removeAllowedIP(peerKey string, allowedIP string) error - close() - getStats(peerKey string) (WGStats, error) -} diff --git a/util/net/net.go b/util/net/net.go index 8d1fcebd0af..61b47dbe7d3 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -4,7 +4,7 @@ import ( "net" "os" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/google/uuid" ) From 8934453b30309e508df09236ac102c0537259291 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 2 Oct 2024 18:29:51 +0200 Subject: [PATCH 22/81] Update management base docker image (#2687) --- management/Dockerfile | 4 ++-- management/Dockerfile.debug | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/management/Dockerfile b/management/Dockerfile index cac640bf431..3b2df262395 100644 --- a/management/Dockerfile +++ b/management/Dockerfile @@ -1,5 +1,5 @@ -FROM ubuntu:22.04 +FROM ubuntu:24.04 RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt ENTRYPOINT [ "/go/bin/netbird-mgmt","management"] CMD ["--log-file", "console"] -COPY netbird-mgmt /go/bin/netbird-mgmt \ No newline at end of file +COPY netbird-mgmt /go/bin/netbird-mgmt diff --git a/management/Dockerfile.debug b/management/Dockerfile.debug index f4be366a801..4d9730bd780 100644 --- a/management/Dockerfile.debug +++ b/management/Dockerfile.debug @@ -1,4 +1,4 @@ -FROM ubuntu:22.04 +FROM ubuntu:24.04 RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"] CMD ["--log-file", "console"] From 158936fb15596690003d602c5df918f6522b97c1 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:50:35 +0200 Subject: [PATCH 23/81] [management] Remove file store (#2689) --- client/cmd/testutil_test.go | 13 +- 
client/internal/engine_test.go | 21 +- client/server/server_test.go | 2 +- client/testdata/store.json | 38 - client/testdata/store.sqlite | Bin 0 -> 163840 bytes management/client/client_test.go | 28 +- management/server/account_test.go | 2 +- management/server/dns_test.go | 2 +- management/server/file_store.go | 791 +----------------- management/server/file_store_test.go | 655 --------------- management/server/management_proto_test.go | 49 +- management/server/management_test.go | 9 +- management/server/nameserver_test.go | 2 +- management/server/peer_test.go | 18 +- management/server/route_test.go | 4 +- management/server/sql_store.go | 22 + management/server/sql_store_test.go | 144 ++-- management/server/store.go | 56 +- management/server/store_test.go | 21 +- .../server/testdata/extended-store.json | 120 --- .../server/testdata/extended-store.sqlite | Bin 0 -> 163840 bytes management/server/testdata/store.json | 88 -- management/server/testdata/store.sqlite | Bin 0 -> 163840 bytes .../server/testdata/store_policy_migrate.json | 116 --- .../testdata/store_policy_migrate.sqlite | Bin 0 -> 163840 bytes .../testdata/store_with_expired_peers.json | 130 --- .../testdata/store_with_expired_peers.sqlite | Bin 0 -> 163840 bytes management/server/testdata/storev1.json | 154 ---- management/server/testdata/storev1.sqlite | Bin 0 -> 163840 bytes management/server/user_test.go | 65 +- 30 files changed, 259 insertions(+), 2291 deletions(-) delete mode 100644 client/testdata/store.json create mode 100644 client/testdata/store.sqlite delete mode 100644 management/server/file_store_test.go delete mode 100644 management/server/testdata/extended-store.json create mode 100644 management/server/testdata/extended-store.sqlite delete mode 100644 management/server/testdata/store.json create mode 100644 management/server/testdata/store.sqlite delete mode 100644 management/server/testdata/store_policy_migrate.json create mode 100644 management/server/testdata/store_policy_migrate.sqlite delete mode 100644 management/server/testdata/store_with_expired_peers.json create mode 100644 management/server/testdata/store_with_expired_peers.sqlite delete mode 100644 management/server/testdata/storev1.json create mode 100644 management/server/testdata/storev1.sqlite diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index f0dc8bf214c..033d1bb6ab8 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -3,7 +3,6 @@ package cmd import ( "context" "net" - "path/filepath" "testing" "time" @@ -34,18 +33,12 @@ func startTestingServices(t *testing.T) string { if err != nil { t.Fatal(err) } - testDir := t.TempDir() - config.Datadir = testDir - err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json")) - if err != nil { - t.Fatal(err) - } _, signalLis := startSignal(t) signalAddr := signalLis.Addr().String() config.Signal.URI = signalAddr - _, mgmLis := startManagement(t, config) + _, mgmLis := startManagement(t, config, "../testdata/store.sqlite") mgmAddr := mgmLis.Addr().String() return mgmAddr } @@ -70,7 +63,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) { return s, lis } -func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) { +func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) { t.Helper() lis, err := net.Listen("tcp", ":0") @@ -78,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste t.Fatal(err) } s := 
grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 29a8439a2df..3d1983c6bda 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -6,7 +6,6 @@ import ( "net" "net/netip" "os" - "path/filepath" "runtime" "strings" "sync" @@ -824,20 +823,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { func TestEngine_MultiplePeers(t *testing.T) { // log.SetLevel(log.DebugLevel) - dir := t.TempDir() - - err := util.CopyFileContents("../testdata/store.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - err = os.Remove(filepath.Join(dir, "store.json")) //nolint - if err != nil { - t.Fatal(err) - return - } - }() - ctx, cancel := context.WithCancel(CtxInitState(context.Background())) defer cancel() @@ -847,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) { return } defer sigServer.Stop() - mgmtServer, mgmtAddr, err := startManagement(t, dir) + mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite") if err != nil { t.Fatal(err) return @@ -1070,7 +1055,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { return s, lis.Addr().String(), nil } -func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) { +func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) { t.Helper() config := &server.Config{ @@ -1095,7 +1080,7 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index 9b18df4d37f..e534ad7e2d6 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir) if err != nil { return nil, "", err } diff --git a/client/testdata/store.json b/client/testdata/store.json deleted file mode 100644 index 8236f27031e..00000000000 --- a/client/testdata/store.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": null - }, - "Peers": {}, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": 
"edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} \ No newline at end of file diff --git a/client/testdata/store.sqlite b/client/testdata/store.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..118c2bebc9f1fd29751627c36304d301ba156781 GIT binary patch literal 163840 zcmeI5Piz}ke#beIMbfe*#>plVXQN$8iDR>oY?6`{$qU1kWjbrbmJ`Xwu3-ej5&1|> zM9wfXL)k*G1!yP1Vu512XpTMf)I$&3Vu3d2JrzCdvA|x7_S#zy1+s_!=FPm}d*slP z>)j$DzO>BYy!U(mKEL1l{ob45pPi3xTRzJ-9Jg)y`Q_A0DNRfLc|M;?rQW3f7wEru zyg@Hch!Z+$((6$#-%MR}|8^=6&V3Weyqf!F`ZuQ^PG3n+P3NY*p4yq*o4PQ0LNh16 zPW?9Z75$^fyDz8pYiG5TW!Kqb@6#4_&@j8c!_SST>vTJwQ3{W*^yXg5=pU_Xy{kOi zsy{Y5%=H}GY#C;)#yrpPoqc9|M%QDmVbzm!&ung2HttpOx3+FnewwcyT}?HAcPn2_ zuB>{Y8Z}leJlUDe=ZVVk!Lx6**w55nFR7B}y1?cCemxV3dJUv2DjU3f;vGC@2iBa5}AD)khf|*P2n{hnd~!L38Rb)tvbGsbO;(loLLN zN)MRR?UQN!yslZX#fC-q8Nngx=}~o)fLdvgQ22Q!!;SNo?Z`(_6}+CTYMeZ+dW3QE zCPHuIF~8d}qy(!Y4699MY3w3FY&wSHDPsG~wOLC^syVJOf+{c7X_KhyFL75&CJP3G_n#T`C}o8vQdIHDfk&1IkE=EB-3jM6MRd63TN*XFg} z+1Llg(_Ido{lv_$Q|f0*!8UYI{j34k_w=0;8GU|U`%>r5sxQ}nHKn?(e#Ue7E_tFz zA!kIS3(cIyt*ni@onT2+A4qlC0WC;nvn5lfg&8fo$YaDYFlk{`O$gCl&ulWIWwkA8 zJ1(7Iy3b)oePSSH$f008+;rB%ndHBJkF#brE5<~$~fN$AN z>W1ZUpK4H)A+ri%$-s&*7UjPF7!LQ@1*DF|f-Y-zTc#Ur5)Lf($QWLS#?sM#km5(j zY>x8!vGgYt?V}RIu^~yS)-tWO;;a=Zv(<7QjY34rUe~1UC%)aB3#DT5;_Ax9;>F^^;=&!%+o!Zkm)15mHm+V> z_}Rmcx4!z--BfCtI;Q1*pQ1lJKmY_l00ck)1V8`;KmY_l00ck)1ioVg&QIPyd7KLl z;r{;*soWpFV*^BYK>!3m00ck)1V8`;KmY_l00ck)1dc^ua&qQm`2PP~DwjJ}5eR_* z2!H?xfB*=900@8p2!H?xfWSBq=-FCs=KlS+=5}^=ve|2Avf1O@MBgZHu2ojc*O!*p zZeCtmxqNwf>H2c1vUKCd>gwj|`s!xo`t=*EVRl=7p1uvqW%QjL^u-hOX8+Psac!xz zo-Y@#tgK!s)0_2+m)Dn9*UIJa{r=_UKxBQXv^HAg=vs>O>4GKxmoFCSQv~$)E`NUb zG4B75gV|yXAOHd&00JNY0w4eaAOHd&00JN|nm};>AD{mp%?CLk00JNY0w4eaAOHd& z00JNY0wC~hCve=){|E2?ryu>p0|Y<-1V8`;KmY_l00ck)1V8`;Kwyjs1o!{3{~u!o zV`v}%0w4eaAOHd&00JNY0w4eaAV37L|A!fX00@8p2!H?xfB*=900@8p2!O!&6Ttp| z{B?{Gf&d7B00@8p2!H?xfB*=900@9U@cI9V*?&#t>hy#M2!H?xfB*=900@8p2!H?x zfB*=5+X?hupP0RW|J7`E=1lh5oma9mjg`h~ef7Qa(&}0Qiavk<2!H?xfB*=900@8p2!H?xfPg|E*#95r zO$&Vge;h&l{-44ZCm;X#3c|znss$B&$jB1jSh1?$2MDrS*tP6Gkj;C z*`5*n$f9A@L&SgDHj;aAW^=o;aj%lUwRNNN(|q;lu2%DRxAN8G%BmNt(SWLjCp)tl zeSKZ~e8V?)TP(Ria%wK>)4h%Bw=2A@KWb$OdZ^|v(9o-SpFQ>$<1s|1w%KMPYtQud zSUuF9NMnyXmTUT!V;gnTC(+>9x7uv6T5}n_t{c0DLRh4jzAV==JwLDo>Ox%j!|hvl zHntz+Kd3widskT8yt}n?Z+qj`*1deSvCoaiGdiZ>dq%UGzj=4Na_jxAs6ecU{C4GL zWxKMqS=q_+p?Rbx8f>AkJCo5b>soIrw$h-G7nrEiwSBTnem*;FB7<_FQSel%ykDLg zIi1$8>Z9gOSWk%`W}aR>mCwxrJ37F5k! 
zp4SZu2bPz=QMtKs|MtE70{wXE!eZ63eb$UD%ViDHM>WsWn46F(3Hyf~G5DjcC}J1- zq69;d*cKF95(Q&(njSQ#4pYsEpPw2wr$IU4W2p3iIo&>)*3avj6OXtv5rby}w1aNM>cHBgV^83!~na!`uYBR?^@Yb7R!ok`5H zng7Uf_YJdNcPX|pbrjyA=d*S&11$=FG)Z=utMnrh=(tY3Tl1Ca#)lZ?w@tg-Fl&C- z9VpN21Iu;nHnYj6h6>>28dj^{6&ZJ(0&mOQ?V}C!Yk0otcY_Ra#U5kI+?54tj%~9V zd39;V%lUlb8$2kClyTXoU5i*5W`~A-V3Lw!?-95+KGh>z^7dp)^5~S=K}40?m)Z;i zQG!S?u46hS*4bm^A(d{CmkXczFAAT6gk{Oqvq6OL+xi;{-L7ZIH$=Ro}u`- z=kQ&8h7Je8WBaZ`_E}*nti8k!ClAsY{o1_NI~)6;c)H7>v!9q5c1rzBDcFV%s-HFB z`kuaXBBRgGYhUWzS@q@muclPD)z5ei-z85JDUOT?U!j@PxRvEhw-c;l>I11RJD^3c zY_??Tw3eck5_ya`uqCa9stIwo>zPevw5+y8ZO5e(O!qm=s88$#Y-xEat-n7%YRMxt zQR=@=?@ng)@WqE(MxUG0enGB| z9Egw*#|hPXM6zl-V*Ps5iZ*I9^2gzsK^}xP9XDPK)50vu>9|(gB)1VYD!D@uT&z0z zcdKFboxjMTqwxD8i9GyZQW$?EO9`R?DZsbvCUwK|xKA}G%8*$Fv1DMG7mIRVe+-BF z>;h6pVnLTRyDig=*7XM#dt?kRE@SCv&q?v4V>U;5{aE@FiuO^7;nmX`vV2yh8OE3PjT7pq-z2^0`5@*s|@HZVMq3zsI+`sUoIxh725 zD2Z2-wA0Q8yX~3m?^1f|$Ej1L>0h1vuPfCO8=zT$?q z)zYA!y&Ky6L0bPwX_RFSvo%V?Y4eRYGx~dHwcg3txadNYAFLbWEqa*n!(k&Dk__WQ zm^@%CPtU!P(a)aMz6^~elnI~bRHKNB=jH1q{muwIMt)eF4k&S-v5vtjk`zrvq*ZT% z^#txjRLEV6YC6L8hijI2J46)=wg4ai}ufsFd0=^vKPu2a6eB@e~&LrpYHdKg``9adV-Mz zl^o@Sp4UP`{gaJ*!9@n6w#RDwLa|3{F(MO9EZ#r$ifkVpX-mE+D@sP1EL?i^S&JG$ z#Rzpq4ZVKqFTzDlWOM#Yo2&+ypt`CDAMVTIxI z%c^3dww^<+^3~Htx+xJ{FpT_zSWX7lDgGZWo`!vY0oVTA_l%P7o&3q}=X{V$>&E&> zXF}u?Zg0^glce;QQYr1dKibWqgoahuFmOdpjvU{u`#GPS%jlObX}!sCNk4emRV;zx z>#!Hp`8AXuVB+hp{^e}#rwJ%DZ&}O{Qmvd<&5nMy)iUs0u@Gju}|KbV;jvxR6 zAOHd&00JNY0w4eaAOHd&Funxv{eR=DTZ{|@KmY_l00ck)1V8`;KmY_l00aa9-2WFO za0CGm009sH0T2KI5C8!X009sHf$=4P{r~vt79#@z5C8!X009sH0T2KI5C8!X00BV& z@BbGha0CGm009sH0T2KI5C8!X009sHf$=4P`~TysTZ{|@KmY_l00ck)1V8`;KmY_l z00aa9?EeJ`96s009sH0T2KI z5C8!X009sHfk6V;{|^$uIS7CN2!H?xfB*=900@8p2!H?xj2{8K|9|{+iV=YT2!H?x VfB*=900@8p2!H?xfWRPu{{>_kT=f6| literal 0 HcmV?d00001 diff --git a/management/client/client_test.go b/management/client/client_test.go index a082e354b45..313a67617db 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -47,25 +47,18 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { level, _ := log.ParseLevel("debug") log.SetLevel(level) - testDir := t.TempDir() - config := &mgmt.Config{} _, err := util.ReadJson("../server/testdata/management.json", config) if err != nil { t.Fatal(err) } - config.Datadir = testDir - err = util.CopyFileContents("../server/testdata/store.json", filepath.Join(testDir, "store.json")) - if err != nil { - t.Fatal(err) - } lis, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) + store, cleanUp, err := NewSqliteTestStore(t, context.Background(), "../server/testdata/store.sqlite") if err != nil { t.Fatal(err) } @@ -521,3 +514,22 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match") assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") } + +func NewSqliteTestStore(t *testing.T, ctx context.Context, testFile string) (mgmt.Store, func(), error) { + t.Helper() + dataDir := t.TempDir() + err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) + if err != nil { + t.Fatal(err) + } + + store, err := mgmt.NewSqliteStore(ctx, dataDir, nil) + if err != nil { + return nil, nil, err + } + + return store, func() { + store.Close(ctx) + 
os.Remove(filepath.Join(dataDir, "store.db")) + }, nil +} diff --git a/management/server/account_test.go b/management/server/account_test.go index e554ae493ea..198775bc33e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2366,7 +2366,7 @@ func createManager(t TB) (*DefaultAccountManager, error) { func createStore(t TB) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index e033c1a214f..23941495e8b 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -210,7 +210,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { func createDNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/file_store.go b/management/server/file_store.go index 994a4b1eec5..df3e9bb7757 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -2,24 +2,18 @@ package server import ( "context" - "errors" - "net" "os" "path/filepath" "strings" "sync" "time" - "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/status" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/route" "github.com/rs/xid" log "github.com/sirupsen/logrus" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" ) @@ -42,167 +36,9 @@ type FileStore struct { mux sync.Mutex `json:"-"` storeFile string `json:"-"` - // sync.Mutex indexed by resource ID - resourceLocks sync.Map `json:"-"` - globalAccountLock sync.Mutex `json:"-"` - metrics telemetry.AppMetrics `json:"-"` } -func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error { - return f(s) -} - -func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)] - if !ok { - return status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - account.SetupKeys[setupKeyID].UsedTimes++ - - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - allGroup, err := account.GetGroupAll() - if err != nil || allGroup == nil { - return errors.New("all group not found") - } - - allGroup.Peers = append(allGroup.Peers, peerID) - - return nil -} - -func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountId) - 
if err != nil { - return err - } - - account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId) - - return nil -} - -func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, ok := s.Accounts[peer.AccountID] - if !ok { - return status.NewAccountNotFoundError(peer.AccountID) - } - - account.Peers[peer.ID] = peer - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, ok := s.Accounts[accountId] - if !ok { - return status.NewAccountNotFoundError(accountId) - } - - account.Network.Serial++ - - return s.SaveAccount(ctx, account) -} - -func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)] - if !ok { - return nil, status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - setupKey, ok := account.SetupKeys[key] - if !ok { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - - return setupKey, nil -} - -func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - var takenIps []net.IP - for _, existingPeer := range account.Peers { - takenIps = append(takenIps, existingPeer.IP) - } - - return takenIps, nil -} - -func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - existingLabels := []string{} - for _, peer := range account.Peers { - if peer.DNSLabel != "" { - existingLabels = append(existingLabels, peer.DNSLabel) - } - } - return existingLabels, nil -} - -func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Network, nil -} - -type StoredAccount struct{} - // NewFileStore restores a store from the file located in the datadir func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { fs, err := restore(ctx, filepath.Join(dataDir, storeFileName)) @@ -213,25 +49,6 @@ func NewFileStore(ctx context.Context, dataDir string, metrics telemetry.AppMetr return fs, nil } -// NewFilestoreFromSqliteStore restores a store from Sqlite and stores to Filestore json in the file located in datadir -func NewFilestoreFromSqliteStore(ctx context.Context, sqlStore *SqlStore, dataDir string, metrics telemetry.AppMetrics) (*FileStore, error) { - store, err := NewFileStore(ctx, dataDir, metrics) - if err != nil { - return nil, err - } - - err = store.SaveInstallationID(ctx, sqlStore.GetInstallationID()) - if err != nil { - return nil, err - } - - for _, account := range sqlStore.GetAllAccounts(ctx) { - store.Accounts[account.Id] = account - } - - return store, store.persist(ctx, store.storeFile) -} - // restore the state of the store from the file. 
// Creates a new empty store file if doesn't exist func restore(ctx context.Context, file string) (*FileStore, error) { @@ -240,7 +57,6 @@ func restore(ctx context.Context, file string) (*FileStore, error) { s := &FileStore{ Accounts: make(map[string]*Account), mux: sync.Mutex{}, - globalAccountLock: sync.Mutex{}, SetupKeyID2AccountID: make(map[string]string), PeerKeyID2AccountID: make(map[string]string), UserID2AccountID: make(map[string]string), @@ -416,252 +232,6 @@ func (s *FileStore) persist(ctx context.Context, file string) error { return nil } -// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock -func (s *FileStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { - log.WithContext(ctx).Debugf("acquiring global lock") - start := time.Now() - s.globalAccountLock.Lock() - - unlock = func() { - s.globalAccountLock.Unlock() - log.WithContext(ctx).Debugf("released global lock in %v", time.Since(start)) - } - - took := time.Since(start) - log.WithContext(ctx).Debugf("took %v to acquire global lock", took) - if s.metrics != nil { - s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) - } - - return unlock -} - -// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock -func (s *FileStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) { - log.WithContext(ctx).Debugf("acquiring lock for ID %s", uniqueID) - start := time.Now() - value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.Mutex{}) - mtx := value.(*sync.Mutex) - mtx.Lock() - - unlock = func() { - mtx.Unlock() - log.WithContext(ctx).Debugf("released lock for ID %s in %v", uniqueID, time.Since(start)) - } - - return unlock -} - -// AcquireReadLockByUID acquires an ID lock for reading a resource and returns a function that releases the lock -// This method is still returns a write lock as file store can't handle read locks -func (s *FileStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) { - return s.AcquireWriteLockByUID(ctx, uniqueID) -} - -func (s *FileStore) SaveAccount(ctx context.Context, account *Account) error { - s.mux.Lock() - defer s.mux.Unlock() - - if account.Id == "" { - return status.Errorf(status.InvalidArgument, "account id should not be empty") - } - - accountCopy := account.Copy() - - s.Accounts[accountCopy.Id] = accountCopy - - // todo check that account.Id and keyId are not exist already - // because if keyId exists for other accounts this can be bad - for keyID := range accountCopy.SetupKeys { - s.SetupKeyID2AccountID[strings.ToUpper(keyID)] = accountCopy.Id - } - - // enforce peer to account index and delete peer to route indexes for rebuild - for _, peer := range accountCopy.Peers { - s.PeerKeyID2AccountID[peer.Key] = accountCopy.Id - s.PeerID2AccountID[peer.ID] = accountCopy.Id - } - - for _, user := range accountCopy.Users { - s.UserID2AccountID[user.Id] = accountCopy.Id - for _, pat := range user.PATs { - s.TokenID2UserID[pat.ID] = user.Id - s.HashedPAT2TokenID[pat.HashedToken] = pat.ID - } - } - - if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount { - s.PrivateDomain2AccountID[accountCopy.Domain] = accountCopy.Id - } - - return s.persist(ctx, s.storeFile) -} - -func (s *FileStore) DeleteAccount(ctx context.Context, account *Account) error { - s.mux.Lock() - defer s.mux.Unlock() - - if account.Id == "" { - return status.Errorf(status.InvalidArgument, "account id should not be 
empty") - } - - for keyID := range account.SetupKeys { - delete(s.SetupKeyID2AccountID, strings.ToUpper(keyID)) - } - - // enforce peer to account index and delete peer to route indexes for rebuild - for _, peer := range account.Peers { - delete(s.PeerKeyID2AccountID, peer.Key) - delete(s.PeerID2AccountID, peer.ID) - } - - for _, user := range account.Users { - for _, pat := range user.PATs { - delete(s.TokenID2UserID, pat.ID) - delete(s.HashedPAT2TokenID, pat.HashedToken) - } - delete(s.UserID2AccountID, user.Id) - } - - if account.DomainCategory == PrivateCategory && account.IsDomainPrimaryAccount { - delete(s.PrivateDomain2AccountID, account.Domain) - } - - delete(s.Accounts, account.Id) - - return s.persist(ctx, s.storeFile) -} - -// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID -func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { - s.mux.Lock() - defer s.mux.Unlock() - - delete(s.HashedPAT2TokenID, hashedToken) - - return nil -} - -// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID -func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - delete(s.TokenID2UserID, tokenID) - - return nil -} - -// GetAccountByPrivateDomain returns account by private domain -func (s *FileStore) GetAccountByPrivateDomain(_ context.Context, domain string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PrivateDomain2AccountID[strings.ToLower(domain)] - if !ok { - return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountBySetupKey returns account by setup key id -func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] - if !ok { - return nil, status.NewSetupKeyNotFoundError() - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret -func (s *FileStore) GetTokenIDByHashedToken(_ context.Context, token string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - tokenID, ok := s.HashedPAT2TokenID[token] - if !ok { - return "", status.Errorf(status.NotFound, "tokenID not found: provided token doesn't exists") - } - - return tokenID, nil -} - -// GetUserByTokenID returns a User object a tokenID belongs to -func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, error) { - s.mux.Lock() - defer s.mux.Unlock() - - userID, ok := s.TokenID2UserID[tokenID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found: provided tokenID doesn't exists") - } - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Users[userID].Copy(), nil -} - -func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) { - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't 
exists") - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - user := account.Users[userID].Copy() - pat := make([]PersonalAccessToken, 0, len(user.PATs)) - for _, token := range user.PATs { - if token != nil { - pat = append(pat, *token) - } - } - user.PATsG = pat - - return user, nil -} - -func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) { - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups)) - - for _, group := range account.Groups { - groupsSlice = append(groupsSlice, group) - } - - return groupsSlice, nil -} - // GetAllAccounts returns all accounts func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { s.mux.Lock() @@ -673,278 +243,6 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { return all } -// getAccount returns a reference to the Account. Should not return a copy. -func (s *FileStore) getAccount(accountID string) (*Account, error) { - account, ok := s.Accounts[accountID] - if !ok { - return nil, status.NewAccountNotFoundError(accountID) - } - - return account, nil -} - -// GetAccount returns an account for ID -func (s *FileStore) GetAccount(_ context.Context, accountID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountByUser returns a user account -func (s *FileStore) GetAccountByUser(_ context.Context, userID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return nil, status.NewUserNotFoundError(userID) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Copy(), nil -} - -// GetAccountByPeerID returns an account for a given peer ID -func (s *FileStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerID2AccountID[peerID] - if !ok { - return nil, status.Errorf(status.NotFound, "provided peer ID doesn't exists %s", peerID) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - // this protection is needed because when we delete a peer, we don't really remove index peerID -> accountID. - // check Account.Peers for a match - if _, ok := account.Peers[peerID]; !ok { - delete(s.PeerID2AccountID, peerID) - log.WithContext(ctx).Warnf("removed stale peerID %s to accountID %s index", peerID, accountID) - return nil, status.NewPeerNotFoundError(peerID) - } - - return account.Copy(), nil -} - -// GetAccountByPeerPubKey returns an account for a given peer WireGuard public key -func (s *FileStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return nil, status.NewPeerNotFoundError(peerKey) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - // this protection is needed because when we delete a peer, we don't really remove index peerKey -> accountID. 
- // check Account.Peers for a match - stale := true - for _, peer := range account.Peers { - if peer.Key == peerKey { - stale = false - break - } - } - if stale { - delete(s.PeerKeyID2AccountID, peerKey) - log.WithContext(ctx).Warnf("removed stale peerKey %s to accountID %s index", peerKey, accountID) - return nil, status.NewPeerNotFoundError(peerKey) - } - - return account.Copy(), nil -} - -func (s *FileStore) GetAccountIDByPeerPubKey(_ context.Context, peerKey string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return "", status.NewPeerNotFoundError(peerKey) - } - - return accountID, nil -} - -func (s *FileStore) GetAccountIDByUserID(userID string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.UserID2AccountID[userID] - if !ok { - return "", status.NewUserNotFoundError(userID) - } - - return accountID, nil -} - -func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] - if !ok { - return "", status.NewSetupKeyNotFoundError() - } - - return accountID, nil -} - -func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) { - s.mux.Lock() - defer s.mux.Unlock() - - accountID, ok := s.PeerKeyID2AccountID[peerKey] - if !ok { - return nil, status.NewPeerNotFoundError(peerKey) - } - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - for _, peer := range account.Peers { - if peer.Key == peerKey { - return peer.Copy(), nil - } - } - - return nil, status.NewPeerNotFoundError(peerKey) -} - -func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return nil, err - } - - return account.Settings.Copy(), nil -} - -// GetInstallationID returns the installation ID from the store -func (s *FileStore) GetInstallationID() string { - return s.InstallationID -} - -// SaveInstallationID saves the installation ID -func (s *FileStore) SaveInstallationID(ctx context.Context, ID string) error { - s.mux.Lock() - defer s.mux.Unlock() - - s.InstallationID = ID - - return s.persist(ctx, s.storeFile) -} - -// SavePeer saves the peer in the account -func (s *FileStore) SavePeer(_ context.Context, accountID string, peer *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - newPeer := peer.Copy() - - account.Peers[peer.ID] = newPeer - - s.PeerKeyID2AccountID[peer.Key] = accountID - s.PeerID2AccountID[peer.ID] = accountID - - return nil -} - -// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. -// PeerStatus will be saved eventually when some other changes occur. -func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Peers[peerID] - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerID) - } - - peer.Status = &peerStatus - - return nil -} - -// SavePeerLocation stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. 
-// Peer.Location will be saved eventually when some other changes occur. -func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Peers[peerWithLocation.ID] - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID) - } - - peer.Location = peerWithLocation.Location - - return nil -} - -// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things. -func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return err - } - - peer := account.Users[userID] - if peer == nil { - return status.Errorf(status.NotFound, "user %s not found", userID) - } - - peer.LastLogin = lastLogin - - return nil -} - -func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented") -} - // Close the FileStore persisting data to disk func (s *FileStore) Close(ctx context.Context) error { s.mux.Lock() @@ -959,86 +257,3 @@ func (s *FileStore) Close(ctx context.Context) error { func (s *FileStore) GetStoreEngine() StoreEngine { return FileStoreEngine } - -func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error { - return status.Errorf(status.Internal, "SaveUsers is not implemented") -} - -func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { - return status.Errorf(status.Internal, "SaveGroups is not implemented") -} - -func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) { - return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") -} - -func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) { - s.mux.Lock() - defer s.mux.Unlock() - - account, err := s.getAccount(accountID) - if err != nil { - return "", "", err - } - - return account.Domain, account.DomainCategory, nil -} - -// AccountExists checks whether an account exists by the given ID. 
-func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) { - _, exists := s.Accounts[id] - return exists, nil -} - -func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) { - return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented") -} - -func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { - return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented") -} - -func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { - return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") -} - -func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) { - return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") -} - -func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) { - return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") - -} - -func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented") -} - -func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) { - return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented") -} - -func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) { - return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented") -} - -func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) { - return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented") -} - -func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) { - return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented") -} - -func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) { - return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") -} - -func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) { - return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented") -} - -func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) { - return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented") -} diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go deleted file mode 100644 index 56e46b6964e..00000000000 --- a/management/server/file_store_test.go +++ /dev/null @@ -1,655 +0,0 @@ -package server - -import ( - "context" - "crypto/sha256" - "net" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/util" -) - -type accounts struct { - Accounts map[string]*Account -} - -func TestStalePeerIndices(t *testing.T) { - storeDir := 
t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - peerID := "some_peer" - peerKey := "some_peer_key" - account.Peers[peerID] = &nbpeer.Peer{ - ID: peerID, - Key: peerKey, - } - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - account.DeletePeer(peerID) - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - _, err = store.GetAccountByPeerID(context.Background(), peerID) - require.Error(t, err, "expecting to get an error when found stale index") - - _, err = store.GetAccountByPeerPubKey(context.Background(), peerKey) - require.Error(t, err, "expecting to get an error when found stale index") -} - -func TestNewStore(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - if store.Accounts == nil || len(store.Accounts) != 0 { - t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") - } - - if store.SetupKeyID2AccountID == nil || len(store.SetupKeyID2AccountID) != 0 { - t.Errorf("expected to create a new empty SetupKeyID2AccountID map when creating a new FileStore") - } - - if store.PeerKeyID2AccountID == nil || len(store.PeerKeyID2AccountID) != 0 { - t.Errorf("expected to create a new empty PeerKeyID2AccountID map when creating a new FileStore") - } - - if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 { - t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore") - } - - if store.HashedPAT2TokenID == nil || len(store.HashedPAT2TokenID) != 0 { - t.Errorf("expected to create a new empty HashedPAT2TokenID map when creating a new FileStore") - } - - if store.TokenID2UserID == nil || len(store.TokenID2UserID) != 0 { - t.Errorf("expected to create a new empty TokenID2UserID map when creating a new FileStore") - } -} - -func TestSaveAccount(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - - // SaveAccount should trigger persist - err := store.SaveAccount(context.Background(), account) - if err != nil { - return - } - - if store.Accounts[account.Id] == nil { - t.Errorf("expecting Account to be stored after SaveAccount()") - } - - if store.PeerKeyID2AccountID["peerkey"] == "" { - t.Errorf("expecting PeerKeyID2AccountID index updated after SaveAccount()") - } - - if store.UserID2AccountID["testuser"] == "" { - t.Errorf("expecting UserID2AccountID index updated after SaveAccount()") - } - - if store.SetupKeyID2AccountID[setupKey.Key] == "" { - t.Errorf("expecting SetupKeyID2AccountID index updated after SaveAccount()") - } -} - -func TestDeleteAccount(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - store, 
err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - defer store.Close(context.Background()) - - var account *Account - for _, a := range store.Accounts { - account = a - break - } - - require.NotNil(t, account, "failed to restore a FileStore file and get at least one account") - - err = store.DeleteAccount(context.Background(), account) - require.NoError(t, err, "failed to delete account, error: %v", err) - - _, ok := store.Accounts[account.Id] - require.False(t, ok, "failed to delete account") - - for id := range account.Users { - _, ok := store.UserID2AccountID[id] - assert.False(t, ok, "failed to delete UserID2AccountID index") - for _, pat := range account.Users[id].PATs { - _, ok := store.HashedPAT2TokenID[pat.HashedToken] - assert.False(t, ok, "failed to delete HashedPAT2TokenID index") - _, ok = store.TokenID2UserID[pat.ID] - assert.False(t, ok, "failed to delete TokenID2UserID index") - } - } - - for _, p := range account.Peers { - _, ok := store.PeerKeyID2AccountID[p.Key] - assert.False(t, ok, "failed to delete PeerKeyID2AccountID index") - _, ok = store.PeerID2AccountID[p.ID] - assert.False(t, ok, "failed to delete PeerID2AccountID index") - } - - for id := range account.SetupKeys { - _, ok := store.SetupKeyID2AccountID[id] - assert.False(t, ok, "failed to delete SetupKeyID2AccountID index") - } - - _, ok = store.PrivateDomain2AccountID[account.Domain] - assert.False(t, ok, "failed to delete PrivateDomain2AccountID index") - -} - -func TestStore(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - account.Groups["all"] = &group.Group{ - ID: "all", - Name: "all", - Peers: []string{"testpeer"}, - } - account.Policies = append(account.Policies, &Policy{ - ID: "all", - Name: "all", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "all", - Name: "all", - Sources: []string{"all"}, - Destinations: []string{"all"}, - }, - }, - }) - account.Policies = append(account.Policies, &Policy{ - ID: "dmz", - Name: "dmz", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: "dmz", - Name: "dmz", - Enabled: true, - Sources: []string{"all"}, - Destinations: []string{"all"}, - }, - }, - }) - - // SaveAccount should trigger persist - err := store.SaveAccount(context.Background(), account) - if err != nil { - return - } - - restored, err := NewFileStore(context.Background(), store.storeFile, nil) - if err != nil { - return - } - - restoredAccount := restored.Accounts[account.Id] - if restoredAccount == nil { - t.Errorf("failed to restore a FileStore file - missing Account %s", account.Id) - return - } - - if restoredAccount.Peers["testpeer"] == nil { - t.Errorf("failed to restore a FileStore file - missing Peer testpeer") - } - - if restoredAccount.CreatedBy != "testuser" { - t.Errorf("failed to restore a FileStore file - missing Account CreatedBy") - } - - if restoredAccount.Users["testuser"] == nil { - t.Errorf("failed to restore a FileStore file - missing User testuser") - } - - if restoredAccount.Network == nil { - t.Errorf("failed to restore a FileStore file - missing Network") - } - - if restoredAccount.Groups["all"] == nil { - t.Errorf("failed to restore a FileStore file - missing Group all") 
- } - - if len(restoredAccount.Policies) != 2 { - t.Errorf("failed to restore a FileStore file - missing Policies") - return - } - - assert.Equal(t, account.Policies[0], restoredAccount.Policies[0], "failed to restore a FileStore file - missing Policy all") - assert.Equal(t, account.Policies[1], restoredAccount.Policies[1], "failed to restore a FileStore file - missing Policy dmz") -} - -func TestRestore(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - - require.NotNil(t, account, "failed to restore a FileStore file - missing account bf1c8084-ba50-4ce7-9439-34653001fc3b") - - require.NotNil(t, account.Users["edafee4e-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User edafee4e-63fb-11ec-90d6-0242ac120003") - - require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore file - missing Account User f4f6d672-63fb-11ec-90d6-0242ac120003") - - require.NotNil(t, account.Network, "failed to restore a FileStore file - missing Account Network") - - require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB") - - require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore wrong PATs length") - - require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length") - - require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length") - - require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length") - - require.Len(t, store.HashedPAT2TokenID, 1, "failed to restore a FileStore wrong HashedPAT2TokenID mapping length") - - require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length") -} - -func TestRestoreGroups_Migration(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - // create default group - account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - account.Groups = map[string]*group.Group{ - "cfefqs706sqkneg59g3g": { - ID: "cfefqs706sqkneg59g3g", - Name: "All", - }, - } - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err, "failed to save account") - - // restore account with default group with empty Issue field - if store, err = NewFileStore(context.Background(), storeDir, nil); err != nil { - return - } - account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - - require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups") - require.Equal(t, group.GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") -} - -func TestGetAccountByPrivateDomain(t *testing.T) { - storeDir := t.TempDir() - - err := 
util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - existingDomain := "test.com" - - account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) - require.NoError(t, err, "should found account") - require.Equal(t, existingDomain, account.Domain, "domains should match") - - _, err = store.GetAccountByPrivateDomain(context.Background(), "missing-domain.com") - require.Error(t, err, "should return error on domain lookup") -} - -func TestFileStore_GetAccount(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - if expected == nil { - t.Fatalf("expected account doesn't exist") - return - } - - account, err := store.GetAccount(context.Background(), expected.Id) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, expected.IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) - assert.Equal(t, expected.DomainCategory, account.DomainCategory) - assert.Equal(t, expected.Domain, account.Domain) - assert.Equal(t, expected.CreatedBy, account.CreatedBy) - assert.Equal(t, expected.Network.Identifier, account.Network.Identifier) - assert.Len(t, account.Peers, len(expected.Peers)) - assert.Len(t, account.Users, len(expected.Users)) - assert.Len(t, account.SetupKeys, len(expected.SetupKeys)) - assert.Len(t, account.Routes, len(expected.Routes)) - assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups)) -} - -func TestFileStore_GetTokenIDByHashedToken(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken - tokenID, err := store.GetTokenIDByHashedToken(context.Background(), hashedToken) - if err != nil { - t.Fatal(err) - } - - expectedTokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID - assert.Equal(t, expectedTokenID, tokenID) -} - -func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - store.HashedPAT2TokenID["someHashedToken"] = "someTokenId" - - err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken") - if err != nil { - t.Fatal(err) - } - - assert.Empty(t, store.HashedPAT2TokenID["someHashedToken"]) -} - -func TestFileStore_DeleteTokenID2UserIDIndex(t *testing.T) { - store := newStore(t) - store.TokenID2UserID["someTokenId"] = "someUserId" - - err := store.DeleteTokenID2UserIDIndex("someTokenId") - if err != nil { - t.Fatal(err) - } - - 
assert.Empty(t, store.TokenID2UserID["someTokenId"]) -} - -func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234")) - _, err = store.GetTokenIDByHashedToken(context.Background(), string(wrongToken[:])) - - assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid") -} - -func TestFileStore_GetUserByTokenID(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID - user, err := store.GetUserByTokenID(context.Background(), tokenID) - if err != nil { - t.Fatal(err) - } - - assert.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id) -} - -func TestFileStore_GetUserByTokenID_Failure(t *testing.T) { - storeDir := t.TempDir() - storeFile := filepath.Join(storeDir, "store.json") - err := util.CopyFileContents("testdata/store.json", storeFile) - if err != nil { - t.Fatal(err) - } - - accounts := &accounts{} - _, err = util.ReadJson(storeFile, accounts) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - t.Fatal(err) - } - - wrongTokenID := "someNonExistingTokenID" - _, err = store.GetUserByTokenID(context.Background(), wrongTokenID) - - assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid") -} - -func TestFileStore_SavePeerStatus(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - - account, err := store.getAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") - if err != nil { - t.Fatal(err) - } - - // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) - assert.Error(t, err) - - // save new status of existing peer - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, - } - - err = store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) - } - - err = store.SavePeerStatus(account.Id, "testpeer", newStatus) - if err != nil { - t.Fatal(err) - } - account, err = store.getAccount(account.Id) - if err != nil { - t.Fatal(err) - } - - actual := account.Peers["testpeer"].Status - assert.Equal(t, newStatus, *actual) -} - 
-func TestFileStore_SavePeerLocation(t *testing.T) { - storeDir := t.TempDir() - - err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) - if err != nil { - t.Fatal(err) - } - - store, err := NewFileStore(context.Background(), storeDir, nil) - if err != nil { - return - } - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - peer := &nbpeer.Peer{ - AccountID: account.Id, - ID: "testpeer", - Location: nbpeer.Location{ - ConnectionIP: net.ParseIP("10.0.0.0"), - CountryCode: "YY", - CityName: "City", - GeoNameID: 1, - }, - Meta: nbpeer.PeerSystemMeta{}, - } - // error is expected as peer is not in store yet - err = store.SavePeerLocation(account.Id, peer) - assert.Error(t, err) - - account.Peers[peer.ID] = peer - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - peer.Location.ConnectionIP = net.ParseIP("35.1.1.1") - peer.Location.CountryCode = "DE" - peer.Location.CityName = "Berlin" - peer.Location.GeoNameID = 2950159 - - err = store.SavePeerLocation(account.Id, account.Peers[peer.ID]) - assert.NoError(t, err) - - account, err = store.GetAccount(context.Background(), account.Id) - require.NoError(t, err) - - actual := account.Peers[peer.ID].Location - assert.Equal(t, peer.Location, actual) -} - -func newStore(t *testing.T) *FileStore { - t.Helper() - store, err := NewFileStore(context.Background(), t.TempDir(), nil) - if err != nil { - t.Errorf("failed creating a new store") - } - - return store -} diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index ff09129bd82..f8ab46d8176 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -6,7 +6,6 @@ import ( "io" "net" "os" - "path/filepath" "runtime" "sync" "sync/atomic" @@ -89,14 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error func Test_SyncProtocol(t *testing.T) { dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{ + mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -117,6 +109,7 @@ func Test_SyncProtocol(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return @@ -412,18 +405,18 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { } } -func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) { +func startManagementForTest(t *testing.T, testFile string, config *Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") if err != nil { - return nil, nil, "", err + return nil, nil, "", nil, err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := NewTestStoreFromJson(context.Background(), config.Datadir) + + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), testFile) if err != nil { - return nil, nil, "", err + t.Fatal(err) } - t.Cleanup(cleanUp) peersUpdateManager := NewPeersUpdateManager(nil) eventStore 
:= &activity.InMemoryEventStore{} @@ -437,7 +430,8 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA eventStore, nil, false, MocIntegratedValidator{}, metrics) if err != nil { - return nil, nil, "", err + cleanup() + return nil, nil, "", cleanup, err } secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) @@ -445,7 +439,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA ephemeralMgr := NewEphemeralManager(store, accountManager) mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr) if err != nil { - return nil, nil, "", err + return nil, nil, "", cleanup, err } mgmtProto.RegisterManagementServiceServer(s, mgmtServer) @@ -455,7 +449,7 @@ func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultA } }() - return s, accountManager, lis.Addr().String(), nil + return s, accountManager, lis.Addr().String(), cleanup, nil } func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.ClientConn, error) { @@ -475,6 +469,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie return mgmtProto.NewManagementServiceClient(conn), conn, nil } + func Test_SyncStatusRace(t *testing.T) { if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { t.Skip("Skipping on CI and Postgres store") @@ -488,15 +483,8 @@ func Test_SyncStatusRace(t *testing.T) { func testSyncStatusRace(t *testing.T) { t.Helper() dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{ + mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -517,6 +505,7 @@ func testSyncStatusRace(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return @@ -665,15 +654,8 @@ func Test_LoginPerformance(t *testing.T) { t.Run(bc.name, func(t *testing.T) { t.Helper() dir := t.TempDir() - err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) - if err != nil { - t.Fatal(err) - } - defer func() { - os.Remove(filepath.Join(dir, "store.json")) //nolint - }() - mgmtServer, am, _, err := startManagementForTest(t, &Config{ + mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -694,6 +676,7 @@ func Test_LoginPerformance(t *testing.T) { Datadir: dir, HttpConfig: nil, }) + defer cleanup() if err != nil { t.Fatal(err) return diff --git a/management/server/management_test.go b/management/server/management_test.go index 3956d96b114..ba27dc5e885 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -5,7 +5,6 @@ import ( "math/rand" "net" "os" - "path/filepath" "runtime" sync2 "sync" "time" @@ -52,8 +51,6 @@ var _ = Describe("Management service", func() { dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*") Expect(err).NotTo(HaveOccurred()) - err = util.CopyFileContents("testdata/store.json", filepath.Join(dataDir, "store.json")) - 
Expect(err).NotTo(HaveOccurred()) var listener net.Listener config := &server.Config{} @@ -61,7 +58,7 @@ var _ = Describe("Management service", func() { Expect(err).NotTo(HaveOccurred()) config.Datadir = dataDir - s, listener = startServer(config) + s, listener = startServer(config, dataDir, "testdata/store.sqlite") addr = listener.Addr().String() client, conn = createRawClient(addr) @@ -530,12 +527,12 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie return mgmtProto.NewManagementServiceClient(conn), conn } -func startServer(config *server.Config) (*grpc.Server, net.Listener) { +func startServer(config *server.Config, dataDir string, testFile string) (*grpc.Server, net.Listener) { lis, err := net.Listen("tcp", ":0") Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromJson(context.Background(), config.Datadir) + store, _, err := server.NewTestStoreFromSqlite(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 5f8545243a0..7dbd4420c10 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -773,7 +773,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { func createNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 387adb91daf..225571f624f 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1004,7 +1004,11 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} @@ -1065,7 +1069,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} @@ -1127,7 +1135,11 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() eventStore := &activity.InMemoryEventStore{} diff --git a/management/server/route_test.go b/management/server/route_test.go index b556816be7a..fbe0221020a 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1257,7 +1257,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { func createRouterStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := 
NewTestStoreFromJson(context.Background(), dataDir) + store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -1737,7 +1737,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { } assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) - //peerD is also the routing peer for route1, should contain same routes firewall rules as peerA + // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 85c68ef4488..cce748a0f84 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -915,6 +915,28 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, return store, nil } +// NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. +func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { + store, err := NewPostgresqlStore(ctx, dsn, metrics) + if err != nil { + return nil, err + } + + err = store.SaveInstallationID(ctx, sqliteStore.GetInstallationID()) + if err != nil { + return nil, err + } + + for _, account := range sqliteStore.GetAllAccounts(ctx) { + err := store.SaveAccount(ctx, account) + if err != nil { + return nil, err + } + } + + return store, nil +} + func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { var setupKey SetupKey result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). 
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 64ef368312c..dc07849d9bf 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -7,7 +7,6 @@ import ( "net" "net/netip" "os" - "path/filepath" "runtime" "testing" "time" @@ -25,7 +24,6 @@ import ( "github.com/netbirdio/netbird/management/server/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/util" ) func TestSqlite_NewStore(t *testing.T) { @@ -347,7 +345,11 @@ func TestSqlite_GetAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -367,7 +369,11 @@ func TestSqlite_SavePeer(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + if err != nil { + t.Fatal(err) + } + defer cleanup() account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -415,7 +421,11 @@ func TestSqlite_SavePeerStatus(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -468,8 +478,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -519,8 +532,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } existingDomain := "test.com" account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) @@ -539,8 +555,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -560,8 +579,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/store.json") - + store, cleanup, err := NewSqliteTestStore(context.Background(), 
t.TempDir(), "testdata/store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } id := "9dj38s35-63fb-11ec-90d6-0242ac120003" user, err := store.GetUserByTokenID(context.Background(), id) @@ -668,24 +690,9 @@ func newSqliteStore(t *testing.T) *SqlStore { t.Helper() store, err := NewSqliteStore(context.Background(), t.TempDir(), nil) - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - -func newSqliteStoreFromFile(t *testing.T, filename string) *SqlStore { - t.Helper() - - storeDir := t.TempDir() - - err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) - require.NoError(t, err) - - fStore, err := NewFileStore(context.Background(), storeDir, nil) - require.NoError(t, err) - - store, err := NewSqliteStoreFromFileStore(context.Background(), fStore, storeDir, nil) + t.Cleanup(func() { + store.Close(context.Background()) + }) require.NoError(t, err) require.NotNil(t, store) @@ -733,32 +740,31 @@ func newPostgresqlStore(t *testing.T) *SqlStore { return store } -func newPostgresqlStoreFromFile(t *testing.T, filename string) *SqlStore { +func newPostgresqlStoreFromSqlite(t *testing.T, filename string) *SqlStore { t.Helper() - storeDir := t.TempDir() - err := util.CopyFileContents(filename, filepath.Join(storeDir, "store.json")) - require.NoError(t, err) - - fStore, err := NewFileStore(context.Background(), storeDir, nil) - require.NoError(t, err) + store, cleanUpQ, err := NewSqliteTestStore(context.Background(), t.TempDir(), filename) + t.Cleanup(cleanUpQ) + if err != nil { + return nil + } - cleanUp, err := testutil.CreatePGDB() + cleanUpP, err := testutil.CreatePGDB() if err != nil { t.Fatal(err) } - t.Cleanup(cleanUp) + t.Cleanup(cleanUpP) postgresDsn, ok := os.LookupEnv(postgresDsnEnv) if !ok { t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) } - store, err := NewPostgresqlStoreFromFileStore(context.Background(), fStore, postgresDsn, nil) + pstore, err := NewPostgresqlStoreFromSqlStore(context.Background(), store, postgresDsn, nil) require.NoError(t, err) require.NotNil(t, store) - return store + return pstore } func TestPostgresql_NewStore(t *testing.T) { @@ -924,7 +930,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -963,7 +969,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") existingDomain := "test.com" @@ -980,7 +986,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -995,7 +1001,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) } - store := newPostgresqlStoreFromFile(t, "testdata/store.json") + store := newPostgresqlStoreFromSqlite(t, 
"testdata/store.sqlite") id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -1009,12 +1015,15 @@ func TestSqlite_GetTakenIPs(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + defer cleanup() + if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -1054,12 +1063,15 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + if err != nil { + return + } + t.Cleanup(cleanup) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -1096,12 +1108,15 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID) @@ -1118,12 +1133,15 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") @@ -1137,12 +1155,16 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - defer store.Close(context.Background()) + + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + 
if err != nil { + t.Fatal(err) + } existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - _, err := store.GetAccount(context.Background(), existingAccountID) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") diff --git a/management/server/store.go b/management/server/store.go index f34a73c2d41..041c936ae56 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -12,10 +12,11 @@ import ( "strings" "time" - "github.com/netbirdio/netbird/dns" log "github.com/sirupsen/logrus" "gorm.io/gorm" + "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/telemetry" @@ -236,23 +237,29 @@ func getMigrations(ctx context.Context) []migrationFunc { } } -// NewTestStoreFromJson is only used in tests -func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), error) { - fstore, err := NewFileStore(ctx, dataDir, nil) - if err != nil { - return nil, nil, err - } - +// NewTestStoreFromSqlite is only used in tests +func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string) (Store, func(), error) { // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE kind := getStoreEngineFromEnv() if kind == "" { kind = SqliteStoreEngine } - var ( - store Store - cleanUp func() - ) + var store *SqlStore + var err error + var cleanUp func() + + if filename == "" { + store, err = NewSqliteStore(ctx, dataDir, nil) + cleanUp = func() { + store.Close(ctx) + } + } else { + store, cleanUp, err = NewSqliteTestStore(ctx, dataDir, filename) + } + if err != nil { + return nil, nil, err + } if kind == PostgresStoreEngine { cleanUp, err = testutil.CreatePGDB() @@ -265,21 +272,32 @@ func NewTestStoreFromJson(ctx context.Context, dataDir string) (Store, func(), e return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - store, err = NewPostgresqlStoreFromFileStore(ctx, fstore, dsn, nil) + store, err = NewPostgresqlStoreFromSqlStore(ctx, store, dsn, nil) if err != nil { return nil, nil, err } - } else { - store, err = NewSqliteStoreFromFileStore(ctx, fstore, dataDir, nil) - if err != nil { - return nil, nil, err - } - cleanUp = func() { store.Close(ctx) } } return store, cleanUp, nil } +func NewSqliteTestStore(ctx context.Context, dataDir string, testFile string) (*SqlStore, func(), error) { + err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) + if err != nil { + return nil, nil, err + } + + store, err := NewSqliteStore(ctx, dataDir, nil) + if err != nil { + return nil, nil, err + } + + return store, func() { + store.Close(ctx) + os.Remove(filepath.Join(dataDir, "store.db")) + }, nil +} + // MigrateFileStoreToSqlite migrates the file store to the SQLite store. 
func MigrateFileStoreToSqlite(ctx context.Context, dataDir string) error { fileStorePath := path.Join(dataDir, storeFileName) diff --git a/management/server/store_test.go b/management/server/store_test.go index 40c36c9e010..fc821670d65 100644 --- a/management/server/store_test.go +++ b/management/server/store_test.go @@ -14,12 +14,6 @@ type benchCase struct { size int } -var newFs = func(b *testing.B) Store { - b.Helper() - store, _ := NewFileStore(context.Background(), b.TempDir(), nil) - return store -} - var newSqlite = func(b *testing.B) Store { b.Helper() store, _ := NewSqliteStore(context.Background(), b.TempDir(), nil) @@ -28,13 +22,9 @@ var newSqlite = func(b *testing.B) Store { func BenchmarkTest_StoreWrite(b *testing.B) { cases := []benchCase{ - {name: "FileStore_Write", storeFn: newFs, size: 100}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 100}, - {name: "FileStore_Write", storeFn: newFs, size: 500}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 500}, - {name: "FileStore_Write", storeFn: newFs, size: 1000}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 1000}, - {name: "FileStore_Write", storeFn: newFs, size: 2000}, {name: "SqliteStore_Write", storeFn: newSqlite, size: 2000}, } @@ -61,11 +51,8 @@ func BenchmarkTest_StoreWrite(b *testing.B) { func BenchmarkTest_StoreRead(b *testing.B) { cases := []benchCase{ - {name: "FileStore_Read", storeFn: newFs, size: 100}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 100}, - {name: "FileStore_Read", storeFn: newFs, size: 500}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 500}, - {name: "FileStore_Read", storeFn: newFs, size: 1000}, {name: "SqliteStore_Read", storeFn: newSqlite, size: 1000}, } @@ -89,3 +76,11 @@ func BenchmarkTest_StoreRead(b *testing.B) { }) } } + +func newStore(t *testing.T) Store { + t.Helper() + + store := newSqliteStore(t) + + return store +} diff --git a/management/server/testdata/extended-store.json b/management/server/testdata/extended-store.json deleted file mode 100644 index 7f96e57a8f1..00000000000 --- a/management/server/testdata/extended-store.json +++ /dev/null @@ -1,120 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "CreatedBy": "", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "AccountID": "", - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "UpdatedAt": "0001-01-01T00:00:00Z", - "Revoked": false, - "UsedTimes": 0, - "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": ["cfefqs706sqkneg59g2g"], - "UsageLimit": 0, - "Ephemeral": false - }, - "A2C8E62B-38F5-4553-B31E-DD66C696CEBC": { - "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", - "AccountID": "", - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", - "Name": "Faulty key with non existing group", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "UpdatedAt": "0001-01-01T00:00:00Z", - "Revoked": false, - "UsedTimes": 0, - "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": ["abcd"], - "UsageLimit": 0, - "Ephemeral": false - } - }, - "Network": { - "id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": "", - 
"Serial": 0 - }, - "Peers": {}, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "admin", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": ["cfefqs706sqkneg59g3g"], - "PATs": {}, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "AccountID": "", - "Role": "user", - "IsServiceUser": false, - "ServiceUserName": "", - "AutoGroups": null, - "PATs": { - "9dj38s35-63fb-11ec-90d6-0242ac120003": { - "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003", - "UserID": "", - "Name": "", - "HashedToken": "SoMeHaShEdToKeN", - "ExpirationDate": "2023-02-27T00:00:00Z", - "CreatedBy": "user", - "CreatedAt": "2023-01-01T00:00:00Z", - "LastUsed": "2023-02-01T00:00:00Z" - } - }, - "Blocked": false, - "LastLogin": "0001-01-01T00:00:00Z" - } - }, - "Groups": { - "cfefqs706sqkneg59g4g": { - "ID": "cfefqs706sqkneg59g4g", - "Name": "All", - "Peers": [] - }, - "cfefqs706sqkneg59g3g": { - "ID": "cfefqs706sqkneg59g3g", - "Name": "AwesomeGroup1", - "Peers": [] - }, - "cfefqs706sqkneg59g2g": { - "ID": "cfefqs706sqkneg59g2g", - "Name": "AwesomeGroup2", - "Peers": [] - } - }, - "Rules": null, - "Policies": [], - "Routes": null, - "NameServerGroups": null, - "DNSSettings": null, - "Settings": { - "PeerLoginExpirationEnabled": false, - "PeerLoginExpiration": 86400000000000, - "GroupsPropagationEnabled": false, - "JWTGroupsEnabled": false, - "JWTGroupsClaimName": "" - } - } - }, - "InstallationID": "" -} diff --git a/management/server/testdata/extended-store.sqlite b/management/server/testdata/extended-store.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..81aea8118ccf7d3af562ddece1f3007f1e8fc942 GIT binary patch literal 163840 zcmeI5Piz}ke#beYMN+aQ#_=W-C(-UmiDMI%Y*LhD$qU1kW!f=p%ZX%R*D!+NhyTukw&U@-<(bEEZE!smbz2#6Km;T zzyElfJ{<86=v@kZ9`y0u#5wD4N4aqJTQ2iP_S^B_9lJk%F*Q1#9sPQAGr2Q*CV52C zM!rt`KJg{}qsQ6TQ}X4LQbISX?2-MTK@C*3wquH8wQia1maP`t_Y<9|*V6LGD^ll@ z$kwZm)E2XB)6g2KR;e)CRvmMf8MfNCnWgI0_}t@b8>Q8|CFREY)zVLu^1;=Vl{@Q7 zIli*;nX-4Xa_;fwWLjQXkv?B_wCx6qFA$%a_1biI^~$Z%z;h{QsI_IqVUL{I;I!VM zp*5MGwWHZPtm?w>)7YbyZfTBg8fsN@NR+=idXvqTD;A^ARdsvc7xq&$CoI>{Y=@(z zCIp2)+PHChb>p6LvvkjGU2b;m&idxvjny0Lca?H&SLni4TbkQj5wa`H<4aB*C<3PwX|EloO^QVXj+zK>FGUT20}!<9E&R1EgR8ks2WeLme~l@ zs@p0xc2BP`0ky0c?IyGI%52%-T4fC;P8doLi>qM^;bY7jUcSz^8)jvf>o%xW=(}Oo zbt9zkz*?fVq1$$w7v{F9TrFK&-MV#GnbBJM%xqaV99H*i&tf%#TvkLHvtlws!p?rn z@BG1b=CN}Je~hVx*ybYE<`?u0E_W;?Uy=tIoY0zD5N{LP9jD;Up74-MOBD;GEw6cY2^cq zQ=G!}MmBTWEj5&&w|v8Dl3A)drtO3cP$bq()wCmGyUa3JBa&1xE#C;rqExG)IXs9( zqUDGN~+v+LLj6K~ljV3e5LAweNYOnu1c6QA= zluMDr+E*)PH89j}rta)V^c&>VnJF?N?(jL>oEV|~5p{SsmwA?(b8DwCN~36OFO`-r zPfMMXfe#9XdpLA<6Vv@pshb%JwyuNfX7#wfCm$S1%hS`+Q&~8xu3Y!iXw+?WGoHhD zg$GI~*6K`c z=uMr@9h8nS-RCf)F0mJ|rQ&ExzCJxDWQI;r=zdOZC)09qO6o{~Rr!%vt(r|uH~hmR z{nq4Xd6KRj`H5sjZUyDs7avJ!d1^}f1-UwMAijkEJ`%MaKRIeU{$%)|>9KcC&mX%- zDtQoAH?3f1O>;Oer)B9)jogNJ(vUm!gNt4z|E|~cuJiYE=-v1GeiC{3-Xu5vge-+e z0aAdY8+B@iZVR8PQj{UH;<2P>Mi__+Uq6JyeRhFRLjpmI)!Plt@)ikuI(w-0&qD)g zZ#_uyqovgcdHq28Q;POpiT>CSCsk=^dNbmzBT`zUVLlv$h^D=kM$1tOWPX5)30gsY zuDCxjCYL|~Q6~>#TG12q59HkWk(9hPHE68)Caf0xPf5PcP9%OXIYF;d_G*Hjcz^&1 zfB*=900@8p2!H?xfB*=900=yTKxajIV{78})H`EaKbgoJ&x~d=8NROGt`#cF`Q?SV zZEZ0>w@_gh=2jMpD|5w#rNv@CU#L}z+w=MPg}Fk0EtFu*P9iZ*?Ub^AO3)Jz5C8!X009sH0T2KI5C8!X z009sHfme*c>EzbfF#8Yg`v03m_M2C%gJ>=YfB*=900@8p2!H?xfB*=900@ADm`X{~f+ZJ3L{!_;$4Yd4&I zg}%f`d8j)(3Vo}T!XD{1|Gof4{3X4`=qo7bdlcwa_QG6#d9JXc%;zsIEM1(Zd(ZQW 
zE5)Vd`FZzl_W2?gS(z&=4;DGNmOR}aKPR3_K2P_^)AOu&y+5OESE@6gJU%0FV()px zE)7NO%6CNU%2ih5%s8@_2~DNOY7guS`6c_ouEFYyEA{#M%qRT%|48;IQQt!M)XfB*=900@8p z2!H?xfB*=900=|~@b&*N_bG_=|1g65{(t(>KRiGH1V8`;KmY_l00ck)1V8`;KmY^| zF#&%4AM5`^tYCBv1V8`;KmY_l00ck)1V8`;KmY`Y0M`F70}ucK5C8!X009sH0T2KI z5C8!XIQ#^#{y+RWMh`&%1V8`;KmY_l00ck)1V8`;KtTNd-{e0fvg`DM2MB-w2!H?x zfB*=900@8p2!H?xygURtZ;ecDZM~7nOdQW#zI`GyQCp}jRhKTz&n*>e+jE5itIVzB zt4nkF`Gt9{QkbXT3oH)vU3hJqS@%B~`0Il7`-Tg1g*;WO6qYXL3m1#a=jIFfr3=gM zj~i+&y>niorpeT_OK#IxRjV>rc;rAjf#`pifJcf;yfB*=9 z00@8p2!H?xfB*=900@A9M}U9-|1jGf;`{%H*?hqIAGQSoAOHd&00JNY0w4eaAOHd& z00J)^f$!z_|D^0X-T(j6^%V^Q0T2KI5C8!X009sH0T2KI5CDOflmOlTKen|sarW{X zmoKkWKP@iX#l>gc`ajI2{7v&V`%v55DOK;9H`zM<6+)5z1YvIef|Ac)6wkBbmH+jE zn^~C47kV=NQhgcx`oEO@Q-Yp&fB*=900@8p2!H?xfB*=900@8p2)tqhPA9jQsCyr+(`oC#^zCO%v2&_J2w%KIYEwkO4r(XoP z|H;P4;5v)XU1vem`JqI0o``Sf!f!3m00ck) z1V8`;KmY_l00cnbl_Bs}a#4!gp-9GO6zd~$a&fG`#6o={``?L4={JdE-(>&g*ssQa zckKT7#nk9{cJ%Af&E#Jvcale>-;8{n_b11|@ru;BB(nADBelgW+cdO>s#Pk?wpGX6WrnTtFE3E_s;hwS z_Z;ypIKH+~TD@CRZmeG|{ZuI*+|;siXI&}BS5`h#_BvF~J>Hy5%PT9==c|sk-C*$r z;#0F;o9?b&xm6P9`c5m0(L-4|LtQT`4twOx27T}j4Xw%itR2nXVO3Xqej0n!(k;!= zO+&3}4vF$tM{ly(a>ZiwxvFmO`@(*T=7i-Mn(c5~pe6)`Kiar)dv)WUaLZC!45 z?aun<-Hp{7>vxrMZC4nNt+q7PvDJE6xprrxbmRKES0GS?vQfHL+9<8Bl{OX8HJj8# zoz3O8C(`nwEOkZ$E9Hf3ZlY$}aL6vj@npY=^vbzLAyTRGZh1P%v6OsC9yD&;dJ5ug z;>o3>X<3$~r}u=#h(z&nENTxTEvm8589k+0W+PCVZmZOr`hrBo1eT{B6%*rm;Tu`gfX2Y!OMo683Ekrf1+jg55=9Z^iEnQpP zx^-8Xp)Jap*|KgptnOKs#cHIFvLe!$6_XhfcJ^C-=MT0bkDcp_Vhkz7HW#razhGca z5GMMwww02dkTD2&) zv1l*secNG8J_26z8nH6b=C(yFYYP;e@hMVZ3m*3Qkc1^1|ZL6m|Gxl`LG@8sH zpXw?=l&k5DZd1azs}y(}+IAN$*RN_jn$zYP)y1s=fa6aMv~KP%cFdYhSIH)xc1@nYyzd(QlAbXQs%A__*irU1Egx2f?9z zS1$9cFy+=>5N~69skD4~TI!q(d{8jl!=baAnC^E<-ONz1bsbbUtH<>{`QS)eo}QMT z%EDQ7<+`6nqi(C4@f^M@JWxV$WI*_G&7Arjna;FZd=^vfNwwG>O?ty-OQue9DViyf z#|Q#j++3&}6L;ITR%dEMZ|ZdJpmdDsK8G1~iM@a=6-QI@_31%Vp3o@@-Os7*WLhpx zNgXM$DnGocRkNw-hJSdZ-trWH(vX=3K(v@E@;k=yW28ghqzaM7#e-}Rc_b^d-1 zz59OOPa+TBo8-oykfrb_KnielqfX7xZQ)Z@iZWzYJeKrK^8!)f>xXc-&n^&ZNFZpj zdb^=n-n@QKXAjl>$z>qztvM-vw6yvluOCQ%O3~gc(H|S)q$&+fZ$_MTL`rKk%!h*z z(OlNjXvs@~%nxueK`W@w758Vmf}L8D|&+dft))(l9Jb^28}h}gw=xoDM>5s z3}0NA4mQ@IeFx(k-ty;+lZF{YjB?f!He1D1(bpZVWQCHCIR4~>;E9EM)E0*5k%i?gh z2PPPvE+AjRa|5B_@A`*8JkEACtFq(y5xd{a`i-V{8SUGKI|1QKzpmo?merw7KYKN_ z_QRC?lfod&>}RVM`qQRs@22GoC#B9t`t;UKRx(7;@H zFEF(@`_8j=A-(Dg?)wivJDHYOPf4Arz<63_L$BRA~+fg{XTgY2Y zhbknK3aTvUo=9?9K6Og^qU0{ELxSN7-;P>WI4kOGx>b8q3O-tVbBM8)evQ)7G31M= z2t|GoHUsBkn{A8q?AN-jS80tVN<~&e6mcE1Vg?KImT5UbA%37*no$qxIoJV-zUibA zp_tx{yl&5<4#zJ=-1%-E-cgR+_lxswbiSo+=aJN^379&T##3iRqKwuCb7me+D_mNO&a&IicsZ z5Lf?X<2K)9P%AsEvg<4MU@dxNyn!X!7u}KVq9bjEH)Z{jo+fkW-+0!f#;2mXI-?W4 zb@VUYNlooUZ@B22f4B*6_Uo)q($$#f2cw(1&vl{V@%zU`Z29~>pV{nD(Q!XaCvfrVORM$ zIg^%iIq8XaAu-~eqZOZu_BWatwTZA1#7rMT4e&X|1UxNka&Oq z2!H?xfB*=900@8p2!H?xfB*;_UIP63Ki2<;SHI{P2!H?xfB*=900@8p2!H?xfB*<^ z0$BgU4nP0|KmY_l00ck)1V8`;KmY_l;P4Z`{r`tw$LJvlfB*=900@8p2!H?xfB*=9 z00`jvKWqR5KmY_l00ck)1V8`;KmY_l00a&{0j&QIzmCyE5C8!X009sH0T2KI5C8!X z009ud{r|855C8!X009sH0T2KI5C8!X009s<`~-0Q|M2S=Jp=&|009sH0T2KI5C8!X z009sH0j&RF10VnbAOHd&00JNY0w4eaAOHd&aQF$}{{O?TWAqRNKmY_l00ck)1V8`; zKmY_l00eOTA2t92AOHd&00JNY0w4eaAOHd&00M`f0M`G9U&rVn2!H?xfB*=900@8p z2!H?xfB*>a`~Rivw+VXU0RkWZ0w4eaAOHd&00JNY0w4eaAaM8yB&B3x0)PL1_!NsC zfdB}A00@8p2!H?xfB*=900@9UF9EFodx_u}1V8`;KmY_l00ck)1V8`;KmY^|9|8RR z|KZaqdISO>00JNY0w4eaAOHd&00JNY0=)#V{_iD%V-NrV5C8!X009sH0T2KI5C8!X lID7plVC(-UmiDMI%Y*Li=$@Av+Fe!utoy*D$Iwmy1bdQ7X>PTladeBzaaq9p!8(-MiqJM{lD{THvd z=))OtLPtgV-0$N%iL1`vjrhXp?|qrq)8CK(?%c=YtI3h^^vE|OTf=)Jmxs?N#?Uv3 z-zUDNfAqTYYEr#XnrG(TdW#!L!Zs 
zQ@z0)*S3tBZj?&Qb#>3)XO^osUFPU!xp(gQ_07WCc0s$hal7zyt$2JjMeWgsR_t9_ z@p3V2teAPWHJMUZR+KN+JY%=UdKc)OnhyK4y>{zCp>JE-W$JBF^Vm~wHflCJwTwCw zS$l@N$I1Z=k;a}jOvms{+tSO1N230_XV%$lvE(rNT-J9Fg|J95JXx-0xSo%ex)2rq zaP!{7wav%c{lep*cbVDsM;lw)n``$rwzXnqpX%fwVems+KkaKXJ(70<*{mLdk(7*G6JblR&&uKOytWREittuj`?YqAGYjqjBTpX~koVlQ$7~ zBbRy2hAt(T4PjVyGE05Wc0Jhud1BSpZ8s*i&m4=@Vo4?25k^qtr5ZKE^MhDSs-qs; z)emT7a-4z*knSx@?*lnZr z^{cy{;WhmX^6M^Rv1u*~lx)jlB?=LebCC17kI;F9=_})~&zdH&(v1cU`@kS2M~=py zOx!Dkw&d=~mgLeYv;2rEw=cCBM716w{Gg@|$$B1HJuZ2~yM~H5^MwRI`vu;w` zQRyD0dkQmZ6MF$$%8w+~yVLzbrqo2P{W-ZioKo{sN=u2XN&50#WUHKqKDTpc+OAt8=KaqAJuaoZ7#;p3LaVVj{p4$gG)AgpRT z(aM_EaA8ivG3y4ojj&P49g5&$mdU@H6|?R9MGhT>-xo>b;X9Lp_#;`09|cGOo@rI7 z8>Y*Bs!mac%*u}?9V@~}l>7PwKSZ?bbqv~$QXmrnu1C;{>I=lh!kAnF z1w@lPi0#B1=$^=#YePwOeX8GF6DF+Z#HV3@pPfkjXmWzymGtcd{on-xAOHd&00JNY z0w4eaAOHd&00JOz1cBCy^7_ui!>PB&c78T7dVX|dbad3;SMOGGrRD7M!rZQLJv+Bh zVoP%?3;C6~{KDe(d^Ve_l=8dt+4+UJTy`!yujLk3v-#ER;?B7(fT{SXzl~iQUSm0I8Q)y0LY*{j)^*_nri zyH9D?t}U;vt=+gW^NWu^+4%Z5j}nP->Zg+aHbFmlfdB}A00@8p2!H?xfB*=900@8p z2)twjE)DODo#gsMu>b!qk^b%_8z8z10w4eaAOHd&00JNY0w4eaAOHd&a3q1V!&778 zxq{&N|8ydqKC%chK>!3m00ck)1V8`;KmY_l00cl_kO;IaB|WjT^X}Bv*4F6g&GVz9 zC%K8fHov}HSe(B#mtVefeQx3U_59qee6BEe`}X4E`r^vsdg0cs+pJkuFliL`q}F%`Nie=`QY*Y`MfW(GM8KKFLHb>S^BcU z9RJa>S^81|{aoSiAAb_u|Nmnm{l`I4BgO&(AOHd&00JNY0w4eaAOHd&00JN&2n-KT zjN$wLf(DKt00JNY0w4eaAOHd&00JNY0wD165jg4h|NZCx>5u;51p*)d0w4eaAOHd& z00JNY0w4eaATYoL{QLjd{|~T&F*FbW0T2KI5C8!X009sH0T2KI5Fi5B|HBMG00ck) z1V8`;KmY_l00ck)1VCW$31I&}_&UZ2K>!3m00ck)1V8`;KmY_l00cn5|Nj5b{kx~8_f#>z&8u?6;`T{Fsc(|WPK`S<^a(*Kx9e@kz8 zfdB}A00@8p2!H?xfB*=900@8p2>cKTTpr$0PI6&m$p8KSC%Mdk{r?X^t>^;?fB*=9 z00@8p2!H?xfB*=900_hg`1}8pJhZ^~|0fajpZ{0V8wvWs3j{y_1V8`;KmY_l00ck) z1V8`;K;T3KTGy1Zot=p*H($SbbEW)ge%Z}mKk6a>libSRvLCV!jIF&wdE36vHt72X zdHS-!-29T3&93sFEBxL6o2}90WiB1V8`;KmY_l00ck)1V8`;K;We!;NSnp{{N*?89ECBAOHd& z00JNY0w4eaAOHd&00O5bFp>VB#H8}u#JTU%|9bAP$A5S3n~|;IzZu>e zKBN40=$pjv6JOImdR=)nsouPxBuuN!p1PmaOpobC)3f=xUUlqd!_{-a@wL{}D=GD( z6{U59XPf1xdV@KxZ5cJ)D3zG&>YlyNELZn`d4X=0gYN%%+vwec^Xr?1we5m-Z{v31 z=UVajt`@aN8(OheIr|q>{4+^}kKWSw> z^ib3;)6k2W$DVq#(HO#0%cwJvwP(0{tQ=@hq_L+B(=j~Lw)C>$k*NRfnRPZ>EIEum zm-XF4AuLi1PnN40uIJkVbs;MJ;pV-EYnzX?`-R6r?=rLNk2bcpH`nfMY-`2JJ~tj$ zZy37g>eZrl=h0^2-rbF`K%@w5vv8-dS=d-FY-xOGE~$wIo5}1>q}1!G(i(}Z)Gy@v zCTce=kL;44Pj;I~r(9qZJe4YMm#0R~CDj{hzj+hXljDbp=Qqx#R8>{Jddw|`C-S#* zaeLrtag9kWhDvknTBJ17)oEM@W{GiRMa^o~nPZk_i-y1QfFifam|~q+w=?jmZ#k=+*#Xsu&vF|A5Wc`Et;0cs-a~$ ztU~%IYCMfOJu)R>@30{Tf4mii>;hlZ!;mDlnV2n!f{{6mcbZd!#m$MIpY1lMPC4OY zsC0)pJs3->msG`!EH)^*&uGLXJ>IWw5{OsoCuDw^N^y((bv^V^R0XeRG;VP`?Kq3$ zO#}|aWnQzPO9^H}SWBI3P2aO!Pqr$aShaQAjfw3u$6~cuQpt9NW{bR3qh@#&+o{K- zI_j}q{eVVB4k{+~#Os;dG`-iD@lm*kuE*;B3^XbH(InYr zPOKlHK*O=i&5{?JZhVMge%-K|6{F-eosRO%IxrpEsxymxYM=mKu42~OU6FB@De%^e z-8NcZzq;!gUenJYSL`wto4c|=$+j$3BCjsZcsZX>e4Ph{zA_H`tZ5P}-DuFT4-8Ur z)h!4vzg%;-^JDyTij4`T<(lzMYo zXa>=kl@fW3D6sWf3l)3B-KJ|)nO-yN zCbb=v?qRy8FrzlH7qF%LNK(B!-EYYwHIZw7PVNq;)clmvQX;Do;Z-l&b;Gp8>6vb8 z5?P^SU`Ha6tSG3UnECQUC8bVHDZe6DM-D_th~rS)dPH*EcEtMixD{>KX6TQDGo3sL ztJ+Sq7^a0;nA32~x2jZ{QNnSf3F|rWX_$7}qyBDtV)PFQHSs~>Y;OEFW1o)Q z82$IuUyc54YHFmI{I4_rIXrphUxxlZ@gGA*;)C?(lm9*O50n2iF@5fzsqi0yy}EdB zz44Qj`Y5NgR-%dPdr-dNpy7we!{#Zvb!X0WPv~wTF)?_2^f%z$A|2u4*)OJ2YA&aI zd67@{Zu#yME$#ri1>&xtJ*a4|Njv2-b4!ld@VCWsb_gb#o&g{s5xN0s`Un0Y67OZZ zhEv)L{fOIbX5B{9xs4XK5iCIbtzTPlUE6AD&`0lv)_T6&DycbwnTzUJbtB_&!x#0NY&o89Z zwTnt?Dl(piT{BCEy3?eMksh9yWjXboZab2*+l9j2v{WIPR8(a#^ITC=>cxx7mxW+! 
zEeXmMz7@AFIVqvVUI2t|G#HY4Za+D(V_ENa~~%d|)1 zrD8iFinyL#vZIZ8!*;x=kbj~(hETwZ3+{kXd>&souZX&jp3`2F8$+&)&A)NU1-i#pwCC<@jA}&vnuMg+3+| zmzL~>GCJJPvvWV<3)AQOZDS!R(Sq8;Nc>8Ub3&)JkU;-r^K`-sG(!Nma@mdVY zgcFPRFMdU~jgGV>Uz8OkLrrF`y?)f9Mo`fMol!$?p8d;UQB!%XGhB47KZ1lex^>ni z8E7o@gYk>Hr@B$`?7j0mwtVp}S8h;Y@cwFCv0+=MP^*0Pbe3*P_!kUA{~(r={&kA~ zhl{5{-(SGBpZcCr?|UbIw);8v(@9ld>FZ30e2?2(bjhSwI(_d}S^AUR9O}_<+%{{;ygK>!3m z00ck)1V8`;KmY_l00cl_a0&R&|10V56ZC@@2!H?xfB*=900@8p2!H?xfB*=9z~B)W zR)!N3`2PRkDHbCF0T2KI5C8!X009sH0T2KI5CDNr0@(j|62UnLfB*=900@8p2!H?x zfB*=900;~o0et^|@N|k1fdB}A00@8p2!H?xfB*=900@9UCjspLJBi>N1V8`;KmY_l w00ck)1V8`;KmY^=j{u(kA3U96L?8eHAOHd&00JNY0w4eaAOHd&&`IEb0dP~Nf&c&j literal 0 HcmV?d00001 diff --git a/management/server/testdata/store_policy_migrate.json b/management/server/testdata/store_policy_migrate.json deleted file mode 100644 index 1b046e63247..00000000000 --- a/management/server/testdata/store_policy_migrate.json +++ /dev/null @@ -1,116 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": "//8AAA==" - }, - "Dns": null - }, - "Peers": { - "cfefqs706sqkneg59g4g": { - "ID": "cfefqs706sqkneg59g4g", - "Key": "MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=", - "SetupKey": "", - "IP": "100.103.179.238", - "Meta": { - "Hostname": "Ubuntu-2204-jammy-amd64-base", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "crocodile", - "DNSLabel": "crocodile", - "Status": { - "LastSeen": "2023-02-13T12:37:12.635454796Z", - "Connected": true - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K", - "SSHEnabled": false - }, - "cfeg6sf06sqkneg59g50": { - "ID": "cfeg6sf06sqkneg59g50", - "Key": "zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=", - "SetupKey": "", - "IP": "100.103.26.180", - "Meta": { - "Hostname": "borg", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "dingo", - "DNSLabel": "dingo", - "Status": { - "LastSeen": "2023-02-21T09:37:42.565899199Z", - "Connected": false - }, - "UserID": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "SSHKey": "AAAAC3NzaC1lZDI1NTE5AAAAILHW", - "SSHEnabled": true - } - }, - "Groups": { - "cfefqs706sqkneg59g3g": { - "ID": "cfefqs706sqkneg59g3g", - "Name": "All", - "Peers": [ - "cfefqs706sqkneg59g4g", - "cfeg6sf06sqkneg59g50" - ] - } - }, - "Rules": { - "cfefqs706sqkneg59g40": { - "ID": "cfefqs706sqkneg59g40", - "Name": "Default", - "Description": "This is a default rule that allows connections between all the resources", - "Disabled": false, - "Source": [ - "cfefqs706sqkneg59g3g" - ], - "Destination": [ - "cfefqs706sqkneg59g3g" - ], - "Flow": 0 - } - }, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": 
"f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} diff --git a/management/server/testdata/store_policy_migrate.sqlite b/management/server/testdata/store_policy_migrate.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..0c1a491a68d58e019b5256b60338ab8b5f5e40d4 GIT binary patch literal 163840 zcmeI5O>7%SmdDwmB~r2}#_=Q_CDDXd;@CtanawYY9vEI)q7#j6*^w+b9wTTr$s*Mj z*-dvhWov>gKsyO$2UyGwHpc~KPji^V%wmAqoaeOI!(jKaK(343!(8^3!vdMZzW!u4 zKdcW=7MABv%Otz%Rdvzx=C()cwIJFqONK!$w*{EFNHWR^jC>QC=_~={-2}& z?&EcObId)USI+l-)Y~^hi`qYqo5JyLO_^8X-%kDZ#J#DD6XR3y@vp`=!dv6#!pAsi z?5oi4LtoNAdc61Y1b=yk3&~22J<=aEsez)@F;weVY-nnyt&4*Fda677QiT7oz;!QK z*>dfX*k+opDpFIFs#T`zqM>dxMHf3d(?q#8Jon`CdZ~E3lvrI`DSebE?_EtfadRzE z9$s1bT-iBUIr(^FI>HwU+-F5Ysx;Z~0>e||PMdBQuUs#UJXhizwYHov*dt@!JFRo5 zNG;}OZAtnTtJyHzH1?=1Ymym&7*Y75-mG0WDOU^IfT-&(4UR+(fohaA0EnVnhTM`XjY?KpMZ?2bC-(PbI zcY@@JJsinR0<>Zq~$0Iz?b6?!G%)k<{o=yanY?lq_)K`t8R!wbs zYL#`78oMJ`nFY0+P&zH9$<_I?VrrE&nRP;6x?f!MTZlQvoZ;o_ywX&w+oo>4S}lDy z)rPG26z*9|&^BaU@0f*6+f1yKt`;}1-%iX)ZFz3KEGq_UIJT#;IzcWctTd(#$@B>u zyKT4gd)t}A&K~?jOwGqO8L&3Dplfi+6BGO;ew4vkS`!89C03d0H>WlWnv->We9)Zw z_rTfh3`s4(Emgi*8V(p^ajNp*;)Tp{iAXurHko-v`Vj1VpT8@vRDy({CEV~MPGv4b`7gVW+`r|y5ToKE3u)9svZ#AW}3p9fuyRcxkga7O0}Dk zVFs~)R9`)+F78mz$lC>^9vVZFJMz$EyD>wv6lOeBwQW(V)ijDbEZ7TuS2tM89IP^h zCK^u_rUlyJ2(&e|)~OnSk!f|&$!|$Yr!G~Ej@DP6DLb;JDlMjvgSHi5m8;9mUQ_(I zYZMflQl*F1)UT);lF>0U$gk^+1%|m_psFegt5S&YIR}5fOa5loOV-5JjZdBfcwI(vzcL8sKq^aWenLG`lwT;G#-k45;?r@1e9%USj0dhf=A zZmXAZ0N>>wD87(0BGTDrPW=wdjXG^}N>uAh)z}VANd0C@rcM(xns$-L@M7SwiB)+> zi0;no#lQ?u8tgtE8)J51+B+T4%&`88QyDp z?3~l_$M%s(9)vYi&6`=%9L~vUYjR5>x8a=B=MLTABG<^j%XPWu{M{USb^X4ZL>|6B z$&NohOEIGWDZr4G1~o(0EuShJWvq;#H z*+X$~9_mRu>p_YiZK*NJ>wD7oDcUg6 za|7IvpykzPi@OtJatRa=W%3}Z7Cgb=Ku(??o8XscM~yYtghj!97dF?~q0slHWAx0$ zS3>lM2MB-w2!H?xfB*=900@8p2!H?xfWT7-bPL=oo3R_SZ%l69iA7IF$D`4xxvs9% zh3ZmjDZ5aSa;b%Em0egUWHW_@Og5j(q*6k?nyI8y>Fk1#T1cf6LjGbZdoi6^%w?9c zg797{eKD0Xm~Jdq)fR0G?a;1~WK~H;wxq08WtyrKmh378tIFypbE~)JF3t(5)M7rn zm|9HD&ClJC^leHze}1W0EMB@a_vyWl*S`Gu%}{8H+R4Sg4$&VTAOHd&00JNY0w4ea zAOHd&00JNY0!NI%+3@D%e)b>i_5U}a_%}zagJ>=YfB*=900@8p2!H?xfB*=900@A< zQwbao&rZ7c71;Ox$3yY>Q;Q%I1V8`;KmY_l00ck)1V8`;KmY_@Bm!N9i^n!M-=5vr z*oa0ipNvNLa}m9mUS2BY(^nQUOILFX*<3EOa3v#@7FJgB`Q>~ezg)U9AFQ(~c{Zy`y$uFhT_U-@aj44uB5SB)Z>|INW zK5VdH{UuT<`cMJ=y=Oh&``Etz|EEy=PcM=h(H9T^0T2KI5C8!X009sH0T2KI5C8#} zKsX$m#OMEA8h8Z(5C8!X009sH0T2KI5C8!X0D+@NV85ULH}C%+JMp{F^i_Jo0|Y<- z1V8`;KmY_l00ck)1fDknPo~Gt<~CVu;T8Hx<(Zk-+p$k88(Hl3S zCu6Z_wayxOy`IYJ54IK7$Q2s7)PC;h{_&0C&1;)%_3mn?qhyuTqg(4jC9_()`=eVe zefR#|)|K@AJIkG>x}3daebc&-$}Fbyi^5WBuA*vBUs`(NbQ_XE!VK zo1mQ<+hI+$ZT__0Cv(e9wNs-nW~Z-czvcXbt$FlmjedT`2XE^wAY z&X@WW^441$lcbQ%s81pF2l@q{Lb8qhDCEX!uJu9vhg+3K<5qnobM--KRsX29^ii&| zaffZLrmt_^EZk_9O2(aHPa(pE!eTnJG-rcfNT*ZTh5J&gwYwm-YI*Z}`E~ZJnyP9l z{l+W(#IE;Gf%e+$L`NKbfx00@8p2!H?xfB*=900@8p2!H?x90dZeg>&4nJ+Co7oSUSK zfxQ=d{r`(l{NLlhI0_v^8$kdBKmY_l00ck)1V8`;KmY_l-~}X*3N!8?_o9rQ53|XG z+$&+P|9=yTe;xnL3&;;$0s#;J0T2KI5C8!X009sH0T2Lzqe$Rfc$3@D$0f!N^!fkk ze+|WJ^n?cpfB*=900@8p2!H?xfB*=900=zq1iG({O>b_#5{<@AMlatu6^-rZTM_oN zl2s(_3CO^E*R009sH0T2KI5C8!X009sH0T2KI5csYMbhv14^VF&BXslXi^#}Te zR9=6ut*}O}(8x6Q^J@UbW|Mww;NIL|`D|lu-b`-f^}4UxTx#xP^ZGv*|6_>$@Bjf2 z009sH0T2KI5C8!X009sH0T4Li1o&`na&TiH(}>0YH#E)tI&|Wj_`jU^yQ$xvxHoli zVtgt-{?+(K`0vA8;bYvd$G!^vKJ+F1qsMzMPwMwB(|BRtBTYVrD~Pwx@f4|OwmR2M;1l7W<&mN&oR6O zCzsbt#oMLC>e@=_qeOY{rj`>o*AnI7m6gwxoeq_gk2j_xe4)U7Ry3qalMOF0JT>mL z>2~qT^^$e2KWk+}^iWQmqpp_|276@8dwp;Y6{*GCtSw33Vl`WPZW?>kmNm(cRYj~x 
z28o)_hTLNF<*LT$eNC+Fy25UXWccNpl5Uu`Kuz!p-&$Y2QCz>9xK_Grw=Oxqd~=v|d_UE^Q>NuIZ#E>TEJu ziADGv&vnN=D>V!0rirQ@#UQ)1j;9Arq+iZ93M-W=@0F*MoS5J*@uS9#U5{YB#GYI_ z9^rYO`{J%;F;=4WbRuXER$5SFzB4*X)6}M?G+7s^UpsP@S;)!>rPE@XT%9i~rp{QC zSts>rpjCNO}m8M$VHZ|wfYH72nHe|)8PR|yCnwNFGV-_|oFR@a(THL&T zJ26N5nRD}HSut3{u`G?%Ngw5emBzFonLc4-x9xU*Z!2=x*}iCqA^F%Q1GeNA^vr3h z-<;YkXinDg@j-Lymvem#mF_dA>ys1wS)P+Ui?xexGwLx(PmQXZ1cH^C3CW*CB9_Jd ztmXJ9s=}%#8nieoZ7++nPGLI`of)0B=u42>uC=tt*2FDUH~dy*B{o!1)dOPNOjB4h zkW^JQSF>fSRJ$n|byaHxr26Vnb#aG!Mh+?<_0Sla+>wVS+l?8Tr7+{6s%?u>t)@|I zW5HhNySl+z<_MH2{Lvt(FfGsyN1&~#wNBLt3^%KbPJT;LI(4aPbhN(mOxck&RcSGW ze5$Pgt6W`f_L}0yU8BI;lqx;6rhY};kc^I*L9SS5EHHNc0##K}Se3lGZ^Zlat${CE zL1Cnf#vXKJVkJs#>h_LAO7^^m>E5iN?%0yP<+miA4w+&`RR8&X=d=US5D~Lq`*cdI zy~W5w2AV})&aT?&zZvd&X6?(R$f50uRkh|BYA;hZb_4qLavDsvG6FvC0KUr_p@Tti z-@YpueOj2ZYZt7S$(@M^fB7`mo$-8-H{AW9vzHhdbV|KUU$C_uR4=Q~^*wp_ScE@) zn)`ycoK;V*_ijAswt5){@Lm3a;)^3A!k2C4)bGG_rqecOF}1!_jqT8+*Kf9D>NJ<4 znG$&nFR%@p3zdh&-Ht9bnAntCGM(EiJ;ZbmU`9P+&tOZL@d^I@)1#(5zEcRj_Y;+H zgwM=!UCy&AH@u29wI#`ldw6Wnn%pc$(zYWvk*vtBpq%{t78l`XXStt}t0M>EO1Q6M zLF;jogSO+&U-z2PI_Gr!v3(?x2Vo6W^CrVIF>`X-n%t7eZ8#_OxkER&$Tjlsa$W8@ ze>aC-UBB-pk%#Y3vg41>Qp_ko3NU1)LCuhL%cqJIWyq||SkgDm^F%FQzYmA|^a8$y zc!C;hbefXp%#3TWXB*`kwTCiuO*4!PqcNs@jy~R=`;Yq@-q3 zeK-mc&1E%-mb?_m+yFNuXnFP7;_hshTml6|nLLQ91y3+Ikdx=fCivyqQDeIVOcZ zh<`HuKe1m;|5NPriQiG-KiF$^_qqG(TM_=Iz;z4Wz%@Onwcw!cJII~WQ*;~5oEsd_ zgF^1Wu;Qb+03Q_T3m1<+osIB{H@4V{Or@$fsy3;+ypju^!W% zxD{HiC1^UX>$m4D-h7}X=?^+Il{IIh-V~db2ekb6x!^&w3RX@%G&k*N>TA%(!d_l! zq=DJ+o?&Y7-WyNbg>u68uivs-ILuF)FJDiv4>QN%UWs_HGw+p1=Gh0Fs{laz*6&)yEm z>YGd|u@uw4k=N^a(BZhHh`YI)XKu@`l&+U(FL$}Pv0Pj!B`>@-!CyNw$^|*~+QuOG zkKRyTjqpF9$>_?imY7xo6;q7>QZQ zUQXzs7Gmq4Y+N@t8N})qt8TlB-CK(e8E0Tw?F;V6_Rx{G{F}0FNk@~(^RGN@QsYt) zZJp7HUOWER_N1nMsy|!|%s=deR|j=AAZcsN@q@ul-2+{yc>K;uE4F<0wxwLVLi_pU zpkkeK9YC%A-P3uxC}D0GI{v|(PMZ4^{|}p|cH5u9wIBGJ(eP_0-`Vw?Yw-zQER1v} zM1IKSE!r{}mQKIB#a;NbT^t&6;-G8j+fnmJj-z!w=l+=pfBrnz4ck-t{%u!x3gqp> zo(qnzfqWB_HSg+ePfIi{GB+NR-HCT1{QF6+oAd<%zhex90)NKYLFW>5slJkKSn!u8 zH|TejkHd2jKAGg6I2RHF-Z@zDnP7i|nL(TI8^Mt2gRcROKRJ9f2=uM*DxL2&o(JlPu>`V_Fyd7*yAkWpNH_`X!(n7-AL+H2@#KGX5wD14_Iu!rf zUCHAW1V8`;KmY_l00ck)1V8`;KmY_l;HVHd8{VAU&zA<6-~Wf-|2ryLMLR(N1V8`; zKmY_l00ck)1V8`;K;Uo$@cVy<qOK>!3m00ck)1V8`;KmY_l00cnba0IaaKO8rd z0s#;J0T2KI5C8!X009sH0T2LzqeQ^G|DTJ08=^lvKmY_l00ck)1V8`;KmY_l00ck) z1YSG>VJ;ks;q(75o?_7>5C8!X009sH0T2KI5C8!X009u_CxG>TKM@>*00@8p2!H?x zfB*=900@8p2!OzgM*yGyfAMsR9)SP|fB*=900@8p2!H?xfB*=9KtBPj|NDvH7z987 z1V8`;KmY_l00ck)1V8`;UOWQ0|Nq6)DS89~AOHd&00JNY0w4eaAOHd&00R94{ttLC B#c2Ql literal 0 HcmV?d00001 diff --git a/management/server/testdata/store_with_expired_peers.json b/management/server/testdata/store_with_expired_peers.json deleted file mode 100644 index 44c225682e8..00000000000 --- a/management/server/testdata/store_with_expired_peers.json +++ /dev/null @@ -1,130 +0,0 @@ -{ - "Accounts": { - "bf1c8084-ba50-4ce7-9439-34653001fc3b": { - "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", - "Domain": "test.com", - "DomainCategory": "private", - "IsDomainPrimaryAccount": true, - "Settings": { - "PeerLoginExpirationEnabled": true, - "PeerLoginExpiration": 3600000000000 - }, - "SetupKeys": { - "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { - "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", - "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", - "Revoked": false, - "UsedTimes": 0 - - } - }, - "Network": { - "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", - "Net": { - "IP": "100.64.0.0", - "Mask": 
"//8AAA==" - }, - "Dns": null - }, - "Peers": { - "cfvprsrlo1hqoo49ohog": { - "ID": "cfvprsrlo1hqoo49ohog", - "Key": "5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=", - "SetupKey": "72546A29-6BC8-4311-BCFC-9CDBF33F1A48", - "IP": "100.64.114.31", - "Meta": { - "Hostname": "f2a34f6a4731", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "11", - "Platform": "unknown", - "OS": "Debian GNU/Linux", - "WtVersion": "0.12.0", - "UIVersion": "" - }, - "Name": "f2a34f6a4731", - "DNSLabel": "f2a34f6a4731", - "Status": { - "LastSeen": "2023-03-02T09:21:02.189035775+01:00", - "Connected": false, - "LoginExpired": false - }, - "UserID": "", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk", - "SSHEnabled": false, - "LoginExpirationEnabled": true, - "LastLogin": "2023-03-01T19:48:19.817799698+01:00" - }, - "cg05lnblo1hkg2j514p0": { - "ID": "cg05lnblo1hkg2j514p0", - "Key": "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=", - "SetupKey": "", - "IP": "100.64.39.54", - "Meta": { - "Hostname": "expiredhost", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "expiredhost", - "DNSLabel": "expiredhost", - "Status": { - "LastSeen": "2023-03-02T09:19:57.276717255+01:00", - "Connected": false, - "LoginExpired": true - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK", - "SSHEnabled": false, - "LoginExpirationEnabled": true, - "LastLogin": "2023-03-02T09:14:21.791679181+01:00" - }, - "cg3161rlo1hs9cq94gdg": { - "ID": "cg3161rlo1hs9cq94gdg", - "Key": "mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=", - "SetupKey": "", - "IP": "100.64.117.96", - "Meta": { - "Hostname": "testhost", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "22.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "development", - "UIVersion": "" - }, - "Name": "testhost", - "DNSLabel": "testhost", - "Status": { - "LastSeen": "2023-03-06T18:21:27.252010027+01:00", - "Connected": false, - "LoginExpired": false - }, - "UserID": "edafee4e-63fb-11ec-90d6-0242ac120003", - "SSHKey": "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM", - "SSHEnabled": false, - "LoginExpirationEnabled": false, - "LastLogin": "2023-03-07T09:02:47.442857106+01:00" - } - }, - "Users": { - "edafee4e-63fb-11ec-90d6-0242ac120003": { - "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" - }, - "f4f6d672-63fb-11ec-90d6-0242ac120003": { - "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" - } - } - } - } -} \ No newline at end of file diff --git a/management/server/testdata/store_with_expired_peers.sqlite b/management/server/testdata/store_with_expired_peers.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..ed1133211d28b5b2faf4963e833b2ac521ff7fe0 GIT binary patch literal 163840 zcmeI5Pi!06eaAVXB}%esuCv+9D!WQ(InJ7}C2{yi6df2_N}{xi^~b9}>$Ml@a7aE; zW05oT%utrRNCB?1Mi8J#i{@CMIrY#(QUqwW*PMzTg7#7%*P=c2(%T*iY!3a+U(WE4 zCDoh4t9)sh!+G!b-n{qu{oe2Q-pp{g{r;M4u=t*;H6$aR4!suQxX_=(nl&vQR27*e^;`WNV*8gtsTQ&?Cl4m?);2DP>nH`tMJ%R8*I zt4IyzX6;M*KC9X=+%$I7lr_naRYj~y28o*QhTLGcN)?UK=c-sfc7@#($?(h7CEYO5 zQWLzw_qJBo3tJE3cZ(0~)+KH&ZfX^5#}?b!EdT z;IWHu6_<-!#f`<{cHHWkPGqUGiNxL52ruRbnMi+}qR>D6wQ@AVXL+tY>{*c60^Kwp zwWSzjGuHm)UK8n-vyH+^rIvQemlDtJT#WEM&pm%&nSmu@y3(t1Zz1LwbB33z^KxCS9GJTGYPIxT zS8KB3Q@Cd>LEDgZy=4|QZ8N@9TrTXc?Zl^~raX14Br681Iku;Fnv=DEvDcis 
z_rMt{&ZES>phv#I^V(p>_jNp*;_@KH;AXurHkoa*VVj1U88;*~nDy(`&gT`s4 zl}<2D>kzg#(wWg}ioOK7=^9pp%u?J}b;ECfR$@&RRXrefz%+%`14$KCbB&;6m1@={ z!wg~psjhldT|A_ok+%y-Ju&(wx8%Obc4PWxDa?4HY6qfJt!fl^Sg;rRv2L)2Iap;1 zO*Ed$ObfKb5ol^^wN)_!Bh%`lli!e()}B-`T3T0mrX0$esx+8F4%$|LRc=qNcbekI zU8SH{m&zTqrhY};kc^g@L4I9lEHKRd0u@zJScO7_&pG(>EgvCT5oVx_#vZq1VkJsV z>h_^TO7-3t-^>Q(MJZ#k=uT<6np z&~0@x&fvTJ1H~6|21GjB%&FgjxlyZWPKl~rsTw<^38~*~$<%3LM$<0x7+wtQH?b=9 z3DGTGsxh%HH)J}tSGtesp23Vd#9qRd(!(SC%GE(3({~7=^LeBkj_~P8uFZK?dpF4Dei(DoDF7L@5=kMmw*7f^t5_$OU zBs>22EX9lhqyR%!YSau_w|uHdQHIRQj3r$&LQmB4^`~&SFD~F~h$pDATB|N;&LZJZ zW>3W4d8jAttOqH6G^N@gukT4eqG<1w=#35iq$+huZUmflKuW6D)hB}x(X>~SXgNxO z%nfjTf|ggGE$&W?$t6%gl*xmrTJQk91357}G{P@V4jOB&35$aJDQvE@L!qx_*oo-x}M!A053iHat2yYObrx zdqQPCIiHy>OWEXfrowX5`Aj-LozBc<)5)Z;S4o#s$y8=qNKPkHaUpj*nY^7&&16%# zY-;{yGDZ6a(~X&m+Mp9dhji9RvZ|ya`%=cLGDTGi^LCYjRb}PlsnvT^x2J?;a%L_w zlblIT-I`jL^aDzpot-Zf3U}^I{qUn7gqM^5`18;>wUdkeE<}HLfdB}A00@8p2!H?x zfB*=900@8p2z&^z=enC{8ae&CM;&<>wZQ3kyqZPiobT zI6VzXWAw}pdhi6@`!7r<=ck2yJe9ninY*2$oAr~~e0pv^m9lUDPo+(f{IoDXSmfkd zlJsSRY3nae-!QQLZd&gj{lLEd|NBtv_vcBC=nDvd00@8p2!H?xfB*=900@8p2!Mb~ zARLa4;rstC4QxRG1V8`;KmY_l00ck)1V8`;K;WxK;I!ZWH}C%skL#h!-@a6&4PGDs z0w4eaAOHd&00JQJWfOR|8NQL-Wzn0H^n;m-a@Lsfw;qJlmaz%TTnP0rWbFi^no~^30N;V}6 zYuW01nWg=e*3Hbr*1FfMTs)b-O^+PU`ofXfj(@=;z#o{exQSQC7$_lc&+>R(<<8b@)_WNxgTxzOb{q z+W2r&{$PJ`L6#S!?Bk{E(bM^Gre#X;iIzN-mOgX|g$=hdSJJL(6IvO7^gN?JL!y|QuBkju089^PEfRn{{1jh&~v+D`MpJERaNqsz>b>CMaw zxm-R!m#1G3Fg*&1^!s%I-_SL$|A%A04#oZ}_Aj)-3j{y_1V8`;KmY_l00ck)1V8`; zK;V@j@Mbv6^*c8jwljK9u?_Lz>{zAG85(>2uY_X%8vC1=^2&4&4FLfV009sH0T2KI z5C8!X009sHfv+%uWSDVhxi=|n7C6hTFGI6oHg=XfD(v;E15 zc_B}a&oCd4AvNXpn;f<6jZWM4=!pc?xm@Z+Px?4b+gdu)Uao28_5Y#RFGF-zs`&N6uGf4cwwkAxw*2?8Jh z0w4eaAOHd&00JNY0w4eauR4K0$o>DJ*IxB*qkSL%0w4eaAOHd&00JNY0w4eaAOHd{ zL%_WLkM;k{5JXuJ009sH0T2KI5C8!X009sH0T6ig31I#I>gxzC1OX5L0T2KI5C8!X z009sH0T2Lzmn9H2E5QB#FN+c7K>!3m00ck)1V8`;KmY_l00cnbl_r4e|F5)q&^QnP z0T2KI5C8!X009sH0T2KI5cpCEL}ULGy3G9|bm{ZhKVABZ@n2o~X#Dob@OW(av*GRV zUxoL>7r0*xeHQvn=x6kgUN^rs!oNGgg=D45j`YWM*d$rY%)z(6{#*tl?u~!(NGVVqKoDa8H#e%?$RIj82wvtWpS%m*eS+W zHx% zETo$zspO`pvHVhbd;v4bx&!sE>gb^b@D zgVh|%(%2s9qZGH&nARuLCu|%y-OitEMGiaL7xghDAKOI0mfV7#IgNLlQOX-*bYQzMyo0M66B_9Ee*0YabML9zg1a@HC0sgfY<@k6jl!;RaDK@Y{@Fs ztV_n8sx<;qUG=ECct|}X2NjTdV)RXJ$$gXU#`MinnDIo_4n(P1)hM>HU@!Dz-CzxK z1j-ctXpoed7HEef(A3mwt6~I(o7F`pzac5DJ*i@}w65|@Ig~Y3X)uL+s;vO4+@4(T zG{uj*N`bd7l{;un{ffFF87(t|T(QnrVC?z@DypKe3VC(ki1+7P17Ea)!ax~~J#NXw zN|c(^?L&!_?0FB6BP=pOJ?QG>g2P zUA5DHGu(B|+Lud_Lpv5LYSlB;PNr-e2lVUZ)R<~z1bo~Xe3vysdxPMqeOF@iMPbUW zU9dJ|ha(aG-K$)C!t+7iaQBDKPGY3jDRnY^!Pa(Aovbd`_w2hDBK*~>+;iS?Rvo#{ zr{SR6>SUb3clig3FOCcdU$&W3zXQ{mR@0otRJ&3&c1V+6zuA(h(_D&XO5`!Tz}9ar zRO%CVTe?(ZVqI>?bZoD5AJaX98Fh%gge|3qNBEVigQh&bLkOMEBjs>}Pfv1f&a)~v zyoyz|A<2rnd!g5w+$=}Zwj(!@tjMmQl=$=>7vU!-xu1}$BM0J2xZ9zi^|;AF+i~Zw zC(US`V>H;1;a-*=P9 z!*?gy@yBN=W)vU=7_w5MX2`nbQ$>n0WL9P@>6+$wqL!~eg~NSu0bfHrL5d#6NiZ0ILdsY`Mr;H(2uQoXJ|8H9-DvYJFo zUJ7Jxfa?>qy!vc$ceYC|fdZmT9z@lG2k0HhiP@nMesOZpSaVHS6x>f?T4|4(tL^CM zZ$o_Oz0gHr{4d8Ijolgjx5%H5{&i$>xHR&g7ycuB`NFS<{xt+~JnlbV(rJvVN5>DlyIo2=V5})x75NlRMSQc~KGtKp6SqRkwFFJa zb^Z37#hVW_B>i!Vrn2U2)SF_{@_?5AJ{R0;R>8`thvrE;n)(`aV&NpOG|<3ocrP)v zxcSzL&O$oX3-~TQESn{NOX4Osz3GZUOVzt~vCvmR^n0($VLOs0c-V zD{OksMORxI>DjHdELUlbW|a!8gec+~YDM)H=1o;Iyh7%Vs7XrAtLJ0~Wc5uZl~{`D zK9SeydC=jwrHH$EHqSgQyHs2&(z)El!uDceshG&UIl|w)KF9?*_1eZD`1jsYz7gTS zN0ZSjLCZ0l*ct2Y@h1kDOicem=SMb!%e!?4V;TB;_7~?_(rpB_}zdvs#F) zf3k7iJjozd_F3h?RqV-HbjUaZ%W7ZnjBE!TY0H07)-CC1GBNwcizYQL717oi9q7%A ze`Zf=_FnG}7d`V2JK-C>I_r_NHRkxi;7Q#xU8wlvJ6EjO@}p~(a_tK3_pb#N>m2J0 
zYW1Hzy+s!#%oB!=e{iRh=6Q<$hm)sv+h4-9pZS_m|7$0IuuOzs3!WRVmjxi7l{2A}`I+vhJ^_6tPg1-26AA8_b0IO{or4u$5B4{h8MFz%5%ifp_!{8&lmB}(2=uM%DxL2*FOCV{yEeja zCI$|_02DJa_z(Ks_Qk6f!uI=?;n?eKcljOM`%VyHFw;bM(yHF|y91Z;iS|%D!oTwl z*Iu+*BP(<|s!o?N$W!Zqoom6i1oB*MdJ}zrE-fU?a|kVWg4i3p6S)81J%fQQ2!H?x zfB*=900@8p2!H?xfB*=bUjlgk-}%)odIka@00JNY0w4eaAOHd&00JNY0xkhu|944X z3j!bj0w4eaAOHd&00JNY0w4ea=a&H1|L0e?=otur00@8p2!H?xfB*=900@8p2)G1r z|G!HDTMz&N5C8!X009sH0T2KI5C8!XIKKpN{r~*x7Ci$25C8!X009sH0T2KI5C8!X z00EZ(*8eUEY(W47KmY_l00ck)1V8`;KmY_l;QSK6{r~4zx9AxNfB*=900@8p2!H?x zfB*=900_7QaQ)vUfh`Dt00@8p2!H?xfB*=900@8p2%KL6SpT12-J)k800JNY0w4ea zAOHd&00JNY0wCZL!2SO&32Z?C1V8`;KmY_l00ck)1V8`;K;Zlm!1e$0t6TI81V8`; zKmY_l00ck)1V8`;KmY_>0$BgMB(Mbm5C8!X009sH0T2KI5C8!X0D<#Mz`XyTi+vHI zKfFKy1V8`;KmY_l00ck)1V8`;KmY{JAAv9z4n^_(|MRC<^aun%00ck)1V8`;KmY_l z00ck)1iA@e{ohRl`yc=UAOHd&00JNY0w4eaAOHd&aQ+D3`~T-pr|1y~fB*=900@8p z2!H?xfB*=900?vw!1}+N2=+k$1V8`;KmY_l00ck)1V8`;K;Zll!2SQ{Pp9Y+2!H?x WfB*=900@8p2!H?xfB*<|6Zl`7K~6CM literal 0 HcmV?d00001 diff --git a/management/server/testdata/storev1.json b/management/server/testdata/storev1.json deleted file mode 100644 index 674b2b87afa..00000000000 --- a/management/server/testdata/storev1.json +++ /dev/null @@ -1,154 +0,0 @@ -{ - "Accounts": { - "auth0|61bf82ddeab084006aa1bccd": { - "Id": "auth0|61bf82ddeab084006aa1bccd", - "SetupKeys": { - "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD": { - "Id": "831727121", - "Key": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD", - "Name": "One-off key", - "Type": "one-off", - "CreatedAt": "2021-12-24T16:09:45.926075752+01:00", - "ExpiresAt": "2022-01-23T16:09:45.926075752+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:12:45.763424077+01:00" - }, - "EB51E9EB-A11F-4F6E-8E49-C982891B405A": { - "Id": "1769568301", - "Key": "EB51E9EB-A11F-4F6E-8E49-C982891B405A", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-12-24T16:09:45.926073628+01:00", - "ExpiresAt": "2022-01-23T16:09:45.926073628+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:13:06.236748538+01:00" - } - }, - "Network": { - "Id": "a443c07a-5765-4a78-97fc-390d9c1d0e49", - "Net": { - "IP": "100.64.0.0", - "Mask": "/8AAAA==" - }, - "Dns": "" - }, - "Peers": { - "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=": { - "Key": "oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=", - "SetupKey": "EB51E9EB-A11F-4F6E-8E49-C982891B405A", - "IP": "100.64.0.2", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:13:11.244342541+01:00", - "Connected": false - } - }, - "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=": { - "Key": "xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=", - "SetupKey": "1B2B50B0-B3E8-4B0C-A426-525EDB8481BD", - "IP": "100.64.0.1", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:12:49.089339333+01:00", - "Connected": false - } - } - } - }, - "google-oauth2|103201118415301331038": { - "Id": "google-oauth2|103201118415301331038", - "SetupKeys": { - "5AFB60DB-61F2-4251-8E11-494847EE88E9": { - "Id": "2485964613", - "Key": "5AFB60DB-61F2-4251-8E11-494847EE88E9", - "Name": "Default key", - "Type": "reusable", - "CreatedAt": "2021-12-24T16:10:02.238476+01:00", - "ExpiresAt": "2022-01-23T16:10:02.238476+01:00", - 
"Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:12:05.994307717+01:00" - }, - "A72E4DC2-00DE-4542-8A24-62945438104E": { - "Id": "3504804807", - "Key": "A72E4DC2-00DE-4542-8A24-62945438104E", - "Name": "One-off key", - "Type": "one-off", - "CreatedAt": "2021-12-24T16:10:02.238478209+01:00", - "ExpiresAt": "2022-01-23T16:10:02.238478209+01:00", - "Revoked": false, - "UsedTimes": 1, - "LastUsed": "2021-12-24T16:11:27.015741738+01:00" - } - }, - "Network": { - "Id": "b6d0b152-364e-40c1-a8a1-fa7bcac2267f", - "Net": { - "IP": "100.64.0.0", - "Mask": "/8AAAA==" - }, - "Dns": "" - }, - "Peers": { - "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=": { - "Key": "6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=", - "SetupKey": "5AFB60DB-61F2-4251-8E11-494847EE88E9", - "IP": "100.64.0.2", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:12:05.994305438+01:00", - "Connected": false - } - }, - "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=": { - "Key": "Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=", - "SetupKey": "A72E4DC2-00DE-4542-8A24-62945438104E", - "IP": "100.64.0.1", - "Meta": { - "Hostname": "braginini", - "GoOS": "linux", - "Kernel": "Linux", - "Core": "21.04", - "Platform": "x86_64", - "OS": "Ubuntu", - "WtVersion": "" - }, - "Name": "braginini", - "Status": { - "LastSeen": "2021-12-24T16:11:27.015739803+01:00", - "Connected": false - } - } - } - } - } -} \ No newline at end of file diff --git a/management/server/testdata/storev1.sqlite b/management/server/testdata/storev1.sqlite new file mode 100644 index 0000000000000000000000000000000000000000..9a376698e4d226fc08fa68c12fb9bb4cf50375cd GIT binary patch literal 163840 zcmeI5U2GfKb;miPBucg=uCv)poJ29DU9U~Zv^o5y17WKV+N-Q3c_lkpV-XC8nW4`*J9@$Z(o!v!l_ zZtRO4rs=98wM402XSyyL>Mm1sv8yvplpBL{FD`FZ7VlRQtLrP3_Y&3PtEnb7))Upi zl~u1*oyMxkr(4r8zFg+Mzi3Fc78_h(aBAG?)BVMzyOp7BC9Y9#s|kbc8`r&Nol`|> zGdF8T(sx+HhT*2M{f?|jhO8=LLo!Izd^hAayI!qpj6OHS+JP(VrbtFmt|jS)iI%$H z6~4E*y0*CaF!4_1q20UW_2rHAt^1pctLyg@)$LtN7rNMyL_-&w)x_RxFITn_*3fh!OM^`&--(a&VtJT}48|!6gVSHBO~v>;&mE3=7G(B7H_b=w zDhAn%bw1s1BE539QCO+eQ@?yQ`E22QjOTgof|4 zJ;pe#CTwq{Go#xP0||1+HLNz7rMRQ&M$iDQ#HK2$dPr=SX$otFlIp7F8bQ@6)oDqF z8N@i zXg<}L7V3v1(9zUJw{C=HrZq$-zbz@C-rc2pZ6Vvg&C{)=^X7&+5Wx5y#dRXgK1 zW2bM{fn16l+JRVC8=j&1nX+*Z(yy1(WU7@Ba)+mIbJh&)kEkbia}!6oIlFcWqcn>q z_r_!VTeIBZCC>+W(>)kE{lr+mQ}Qzd!Pa(AepZj`d-mpNjGvw5KIScF<;(e>#=>sP z&p3te3Qm+j$QcspY%`~Eht@{jj=3ai^rUKRj~1jsvn5lfg&8fo$YXdhaL~f4Iv_-M zb*ah3mfV)9ZLjnI(>;Y5`NW>Xma=2x{GHihAv4f~;C~*kMPqz+hCAdut8yc;*ihS& zthlG6{nq4WIg+*=xrt;&b_LbsNB6iGKQqJqkX#)(5Ld!Ij)bkpO%B_RyBI!hdF-_5 z_+$G_BoD%xs^+b%X$|M(bTqjwk=t+@4Y)%$xX2Ci@A9_nJAXHaj;`N#lgPvOCfV^P zU@2x4AO#q*(xh(4y5&h@!nyqCYkalB%~PxgB!WAt|ZVQlAV%M9W@HqU|UJGB>~t z2wGl!wz#`6CYL|~Q6>+fYT*X@Cvx)U$T+_|Gir>MEljzqpVJw@*$)Bg~O zKcqLjKmY_l00ck)1V8`;KmY_l00cnb^FZJ*Ir7@p)Y{b4#Di~7O!+MS3kP%zgO&ch*1o@kS(a_y+gd>VUqQeN2C^Ak?-?nMQ+2wR9<$P8TFe zsMYI@LERKeLaFFk56<%@-UDkl|7sd90E5C8!X009sH0T2KI5CDNQPT*-|G(Nla;PtO=ZEa0VZe5(1$mY|zlKHQAg8Ph% z#Y`o)vYbh!(<_x!E}zS!N{jTobs09Z{jZC(bE7HT-<@vNgPi2z`RVt)1*`q{GP)krmxRoj1NDKL5PAFzec0=!f z=Wsu2A(&ZPTeA^la;1E^kShq;6ClVh-d-xCSC&!*;dUmK%jAVrsUireTscQ0s#Hp) zN_mBCOWl@XBG4H9YyF~3LW!Q9Kl<)*WhpOI%9W+mq9EK(*RNE|7^8ilo_q=D*9uFVpaS<1b}}Vx}l$jv2kM zlv&EBm(r=FY^9XSEv1)Ji@8i8mCxiWD@&zZNmyDr&g=u-&KL8UL0TUsGI;QzS^^?M z#-v;w=GH}PM`e;NO;Gw$dxJP-f@5C8!X009sH0T2KI5C8!X z_#-AT9i8R;D|VyiseknjFHd?`IPCrZuKE4{e~$kozWYZE0bK_H5C8!X 
z009sH0T2KI5C8!X0D;c~ffIgdfAoZM(YfdtPK})Z3;KW; z2!H?xfB*=900@A<871&+Wb`WAVh^rdroT&j>CzOvBJ}@aD+#)AQo~DYV=1LyC2Lgx72Qfezlu^zx$r^ zcSFsy553af2YNm|{FJ))@OZXt{dTzdlsdfx^C|VTkf0wHxBksr?^JfckB-yNdDC;g z=8ulg%{_}mhbZB4j}i*Ik8162Jr;J_-TJq_ySMb7n&0TwALch6?#(~i(wBrs5ANT6 zoGA+{PfntQ(^PkoXWM-xWd7hv=uv_H5WD%^0{cPc=TO1Z4@a)D)sxJF{nmc@W_hJF zFV~b!d3ANEq`ZA+ZU4^3clW;~6t&XU&b_yra!z`xFP&|r+~V)F0Gaxe_Kv3AJki0-6wZ8=jXS& z+Wt2W*0P;4X&j8SOOmf0=)3@*c zFGS)>d?Eg?@xM8v!4AU$0T2KI5C8!X009sH0T2KI5CDPSpFldwxYJzC95oA^=E_^t zEO45OMI$$(Z1OZ$H?aTz{nd|>AOHd&00JNY0w4eaAOHd&00Ms^1kTLc|3_YV@qHw{S00@8p2!H?xfB*=900@8p2!OzgCV>0@ zFPbpw2LTWO0T2KI5C8!X009sH0T2Lzmrnru|Cg@=^Z*1v00ck)1V8`;KmY_l00ck) z1YR@&Jpccq38Q`x009sH0T2KI5C8!X009sH0T6ik1aSZVxoLM`@bc@{}g)B8_WLt}_c+HKBCdOq1)^tBR>J)?!wJf%Krb7}UQx<(=8?YP8l; z>$|4ryjm@7w$!Gq1k~x-LRj;%u6NDCrsXA8Dz_IO+`XTelREO;^{T8Gtm#;m#0GZlby(P(tn>5z=F}_a`WPzR zV@`J`$N8%~Cwmra7u{tvVv;^LtZotrS866Ce-MjV7WbpJqcn0SwnR4+mh1Vmg+`V>nYEaJy}zgHdDx_+6u7BZObjcD?!{13cM|; z=A$+BE9!=1bj=KM#X4i5xf>LytBS(vjJ_asuX z=RHjKW=(a+mh_#VCFyj^6f>d*+Yhv92ciKY=D1Gilvrnnk%tU*i@cm&wKIM*-1%l5 z$fd}k9f)qq5a~gMOIn(W!tC&Vls>b$c(Hk^dGId%@(MpLth8Nfdt%a%s;%--$noMlT zZJFBkN)IsIQ<#xY>^W>HJ2uYWnH{#|2{a-2pT}#_7@wWt4mr=N-0&(k)V3rm?&)a1 zHMv=iq-{rTB3Y4LK{ff&Jub%2%y2&>S4R%Sm2i(EVe4^|!?xqDUyobSI&C`s*gg}< zgRrKmd5d9Mm^nEeO>Rr%Hk?KS?$8Y`a)bQ4ye<3A-_4<;>-XIx^6Mcobhn#gtN@}&#C&LiYT2_;2%S(aG4R8a3mRFxG z?yh#pB~U<=$%Cj`xPktOoV+!;h(7Ruk5|L z`+oSP%Q1dK;10{)#5Fyrwc()QJII~pDZ2G%&h<~|ej#^aSn<)^fcJ~^go~#i&ct{@ z;6A!yP4<5I{**-60rU%mU4cI+Z>|M)$_=L1HMwJMi-Xw?7;k#ofLsa34FsmY?H@ex zLAEPt^&Q8L=>2BaZ#2EzXxBFE1&F!z^A*>(tqu(O=-trr*T?y93&Sk4pRFkLr_FAE zEyfowafg$hanXe)!(2CdTXZ|&a=#JvN!oG2PVO<5XIH)w<1by}KDLd;ma*Sw!baf~ zPXw=*_?>a}==fpxv`2}%jCDk-qJW~QNMO}FzcUwh^ItvcDx_1rU?0E! z!KE0#c!fKh@rN(y4S!0t)C6;1(FXZ_n4?7&U6md7N z=9!meS1NZabS-yzacg;TrIIXud7OXe$}kt?)N31q5Ip;;@}(I64O)y|3|o%bMSrfl z?Oz&VGGXa}y--Gv`+0idi`K&Q*{*LaB;~fi4=@t5lH;7vX)VOoKiRl$UStsKJFLFz zD)x9SI%J%QW%Vz7Mb<}0+6rEjbxS&$Ox}F$s6~xSMYMHB4So6ipWBO??N@ulMc?|v zPWV#4&iW*6jX8cWd{OsQH!7aKanXt`-+$dwu3e%1{)Mn&owiP)*5K9C>vU7XykO|~ z2X{GXUZ?mBE}q(be-77v>U&0m@14A``#JB#$9b_l)R_?Z0k^m4lF6X-S0WLv_y@Z= zG@#+IYv{S67K|K!T=#PxU5W8GZ*qrGdr9AW+0|VFdDmgDh38jazKO|Nclnp6C0Z7l z7ao&`<8Q?HJ4x;^83+PF#~2C)!Hlc@&L!+p10~(C5G+q_Fz70uM(1LDGRZx2ZX|}h zbGYIw;qitu!!{8#f&t5iKnEOu^7}`F(AawJ(goT%IwpMM^>Kb9IduAkpqP;%*yx*w z=VvX1hwoa3WAC@!?RWU>jWEJ+rit*lUA^mfhi>DOha-s?fAuPNxNP-CR_Jn6i*93( zr`AI!uZQ~*%5$~pE%bxAw2?5cA#~jZVt?>X;`x8~3I>iK00JNY0w4eaAOHd&00JNY z0w8d93E=nt&aQ4TG7ta(5C8!X009sH0T2KI5C8!Xa0%f4ze@r~5C8!X009sH0T2KI z5C8!X009s{{QUi79#@z5C8!X009sH0T2KI5C8!X00EZ( zp8t1A;0OXB00JNY0w4eaAOHd&00JNY0%w;1?*E@%-C|@Q00JNY0w4eaAOHd&00JNY z0wCZL!2aJQfg=cj00@8p2!H?xfB*=900@8p2%KF4=JWqt{C5%hhZhKd00@8p2!H?x zfB*=900@8p2!O!ZBM{}Hktuxt|LiFiBLV>s009sH0T2KI5C8!X009sHfnEaG|MwEX zIS7CN2!H?xfB*=900@8p2!H?xoIL{g{{Pw2DMkbWAOHd&00JNY0w4eaAOHd&00O-P zu>bERf^!f60T2KI5C8!X009sH0T2KI5IB1T@cjST( Date: Fri, 4 Oct 2024 17:17:01 +0300 Subject: [PATCH 24/81] [management] Refactor User JWT group sync (#2690) * Refactor GetAccountIDByUserOrAccountID Signed-off-by: bcmmbaga * sync user jwt group changes Signed-off-by: bcmmbaga * propagate jwt group changes to peers Signed-off-by: bcmmbaga * fix no jwt groups synced Signed-off-by: bcmmbaga * fix tests and lint Signed-off-by: bcmmbaga * Move the account peer update outside the transaction Signed-off-by: bcmmbaga * move updateUserPeersInGroups to account manager Signed-off-by: bcmmbaga * move event store outside of transaction Signed-off-by: bcmmbaga * get user with update lock Signed-off-by: bcmmbaga * Run jwt sync in transaction 
Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 284 ++++++++++++------ management/server/account_test.go | 169 +++++++---- management/server/mock_server/account_mock.go | 12 +- management/server/sql_store.go | 118 +++++++- management/server/sql_store_test.go | 30 ++ management/server/store.go | 5 +- management/server/user.go | 72 ++++- management/server/user_test.go | 5 +- 8 files changed, 519 insertions(+), 176 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index d5e8c8cf8b1..da320385279 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -76,7 +76,7 @@ type AccountManager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) - GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) + GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) @@ -478,12 +478,12 @@ func (a *Account) GetPeerNetworkMap( } nm := &NetworkMap{ - Peers: peersToConnect, - Network: a.Network.Copy(), - Routes: routesUpdate, - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, + Peers: peersToConnect, + Network: a.Network.Copy(), + Routes: routesUpdate, + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, RoutesFirewallRules: routesFirewallRules, } @@ -843,55 +843,54 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer { return a.Peers[peerID] } -// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. -// Returns true if there are changes in the JWT group membership. -func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { - user, ok := a.Users[userID] - if !ok { - return false - } - +// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. +// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, +// newly groups to create and an error if any occurred. 
+func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { existedGroupsByName := make(map[string]*nbgroup.Group) - for _, group := range a.Groups { + for _, group := range groups { existedGroupsByName[group.Name] = group } - newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) - groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap)) - groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) + newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups) + + groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames) // If no groups are added or removed, we should not sync account if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { - return false + return false, nil, nil, nil } + newGroupsToCreate := make([]*nbgroup.Group, 0) + var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { group = &nbgroup.Group{ - ID: xid.New().String(), - Name: name, - Issued: nbgroup.GroupIssuedJWT, + ID: xid.New().String(), + AccountID: user.AccountID, + Name: name, + Issued: nbgroup.GroupIssuedJWT, } - a.Groups[group.ID] = group + newGroupsToCreate = append(newGroupsToCreate, group) } if group.Issued == nbgroup.GroupIssuedJWT { - newAutoGroups = append(newAutoGroups, group.ID) + newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } } for name, id := range jwtGroupsMap { if !slices.Contains(groupsToRemove, name) { - newAutoGroups = append(newAutoGroups, id) + newUserAutoGroups = append(newUserAutoGroups, id) continue } modified = true } - user.AutoGroups = newAutoGroups - return modified + return modified, newUserAutoGroups, newGroupsToCreate, nil } // UserGroupsAddToPeers adds groups to all peers of user @@ -1262,37 +1261,31 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } -// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. -// If an accountID is provided, it checks if the account exists and returns it. -// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. +// GetAccountIDByUserID retrieves the account ID based on the userID provided. +// If user does have an account, it returns the user's account ID. // If the user doesn't have an account, it creates one using the provided domain. // Returns the account ID or an error if none is found or created. 
-func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { - if accountID != "" { - exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) - if err != nil { - return "", err - } - if !exists { - return "", status.Errorf(status.NotFound, "account %s does not exist", accountID) - } - return accountID, nil +func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) { + if userID == "" { + return "", status.Errorf(status.NotFound, "no valid userID provided") } - if userID != "" { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) - if err != nil { - return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) - } + accountID, err := am.Store.GetAccountIDByUserID(userID) + if err != nil { + if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { + account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + if err != nil { + return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) + } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { - return "", err + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + return "", err + } + return account.Id, nil } - - return account.Id, nil + return "", err } - - return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") + return accountID, nil } func isNil(i idp.Manager) bool { @@ -1796,6 +1789,10 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } + if user.AccountID != accountID { + return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID) + } + if !user.IsServiceUser && claims.Invited { err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { @@ -1803,7 +1800,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai } } - if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { + if err = am.syncJWTGroups(ctx, accountID, claims); err != nil { return "", "", err } @@ -1812,7 +1809,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. 
-func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { +func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err @@ -1823,69 +1820,136 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.JWTGroupsClaimName == "" { - log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") + log.WithContext(ctx).Debugf("JWT groups are enabled but no claim name is set") return nil } - // TODO: Remove GetAccount after refactoring account peer's update - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer func() { + if unlockPeer != nil { + unlockPeer() + } + }() - // Update the account if group membership changes - if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := difference(oldGroups, user.AutoGroups) + var addNewGroups []string + var removeOldGroups []string + var hasChanges bool + var user *User + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user: %w", err) + } - if settings.GroupsPropagationEnabled { - account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) - account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) 
- account.Network.IncSerial() + groups, err := am.Store.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) } - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) + changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames) + if err != nil { + return fmt.Errorf("error getting JWT groups changes: %w", err) + } + + hasChanges = changed + // skip update if no changes + if !changed { return nil } + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + addNewGroups = difference(updatedAutoGroups, user.AutoGroups) + removeOldGroups = difference(user.AutoGroups, updatedAutoGroups) + + user.AutoGroups = updatedAutoGroups + if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { + return fmt.Errorf("error saving user: %w", err) + } + // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) - } - - for _, g := range addNewGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) + groups, err = transaction.GetAccountGroups(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account groups: %w", err) + } + + groupsMap := make(map[string]*nbgroup.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId) + if err != nil { + return fmt.Errorf("error getting user peers: %w", err) + } + + updatedGroups, err := am.updateUserPeersInGroups(groupsMap, peers, addNewGroups, removeOldGroups) + if err != nil { + return fmt.Errorf("error modifying user peers in groups: %w", err) + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + return fmt.Errorf("error incrementing network serial: %w", err) } } + unlockPeer() + unlockPeer = nil + + return nil + }) + if err != nil { + return err + } + + if !hasChanges { + return nil + } - for _, g := range removeOldGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) + for _, g := range addNewGroups { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, + } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta) + } + } + + for _, g := range removeOldGroups { + group, err := 
am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + if err != nil { + log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) + } else { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName, } + am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta) } } + if settings.GroupsPropagationEnabled { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("error getting account: %w", err) + } + + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } + return nil } @@ -1916,7 +1980,17 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) + if claims.AccountId != "" { + exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId) + if err != nil { + return "", err + } + if !exists { + return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId) + } + return claims.AccountId, nil + } + return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) } else if claims.AccountId != "" { userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err != nil { @@ -2229,7 +2303,11 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac routes := make(map[route.ID]*route.Route) setupKeys := map[string]*SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - users[userID] = NewOwnerUser(userID) + + owner := NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + dnsSettings := DNSSettings{ DisabledManagementGroups: make([]string, 0), } @@ -2297,12 +2375,17 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. 
-func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { +func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID + allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) + for _, group := range allGroups { + allGroupsMap[group.ID] = group + } + for _, id := range autoGroups { - if group, ok := allGroups[id]; ok { + if group, ok := allGroupsMap[id]; ok { if group.Issued == nbgroup.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { @@ -2310,5 +2393,6 @@ func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([ } } } + return newAutoGroups, jwtAutoGroups } diff --git a/management/server/account_test.go b/management/server/account_test.go index 198775bc33e..c417e4bc89b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -633,7 +633,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") initAccount, err := manager.Store.GetAccount(context.Background(), accountID) @@ -671,17 +671,16 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { userId := "user-id" domain := "test.domain" - initAccount := newAccountWithId(context.Background(), "", userId, domain) + _ = newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID := initAccount.Id - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain) + accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated + // that happens inside the GetAccountIDByUserID where the id is getting generated // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it - initAccount, err = manager.Store.GetAccount(context.Background(), accountID) + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") claims := jwtclaims.AuthorizationClaims{ @@ -885,7 +884,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { } } -func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { +func TestAccountManager_GetAccountByUserID(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -894,7 +893,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "") if err != nil { t.Fatal(err) } @@ 
-903,14 +902,13 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { return } - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) - } + exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + assert.NoError(t, err) + assert.True(t, exists, "expected to get existing account after creation using userid") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), "", "") if err == nil { - t.Errorf("expected an error when user and account IDs are empty") + t.Errorf("expected an error when user ID is empty") } } @@ -1669,7 +1667,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) @@ -1684,7 +1682,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1696,7 +1694,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1742,7 +1740,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1770,7 +1768,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. 
}, } - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1790,7 +1788,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1802,7 +1800,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }) require.NoError(t, err, "unable to add peer") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to get the account") account, err := manager.Store.GetAccount(context.Background(), accountID) @@ -1850,7 +1848,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ @@ -1861,9 +1859,6 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") - require.NoError(t, err, "unable to get account by ID") - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") @@ -2199,8 +2194,12 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } func TestAccount_SetJWTGroups(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + // create a new account account := &Account{ + Id: "accountID", Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2211,62 +2210,120 @@ func TestAccount_SetJWTGroups(t *testing.T) { Groups: map[string]*group.Group{ "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, - Settings: &Settings{GroupsPropagationEnabled: true}, + Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, Users: map[string]*User{ - "user1": {Id: "user1"}, - "user2": {Id: "user2"}, + "user1": {Id: "user1", AccountID: "accountID"}, + "user2": {Id: "user2", AccountID: "accountID"}, }, } + assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") + t.Run("empty jwt groups", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{}) - assert.False(t, updated, "account should not be updated") - assert.Empty(t, 
account.Users["user1"].AutoGroups, "auto groups must be empty") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{}}, + } + err := manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Empty(t, user.AutoGroups, "auto groups must be empty") }) t.Run("jwt match existing api group", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1"}) - assert.False(t, updated, "account should not be updated") - assert.Equal(t, 0, len(account.Users["user1"].AutoGroups)) - assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + } + err := manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0) + + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + assert.NoError(t, err, "unable to get group") + assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} + assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"])) - updated := account.SetJWTGroups("user1", []string{"group1"}) - assert.False(t, updated, "account should not be updated") - assert.Equal(t, 1, len(account.Users["user1"].AutoGroups)) - assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1) + + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + assert.NoError(t, err, "unable to get group") + assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) t.Run("add jwt group", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1", "group2"}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Groups, 2, "new group should be added") - assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added") - assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group2"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), 
LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) t.Run("existed group not update", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group2"}) - assert.False(t, updated, "account should not be updated") - assert.Len(t, account.Groups, 2, "groups count should not be changed") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group2"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) t.Run("add new group", func(t *testing.T) { - updated := account.SetJWTGroups("user2", []string{"group1", "group3"}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Groups, 3, "new group should be added") - assert.Len(t, account.Users["user2"].AutoGroups, 1, "new group should be added") - assert.Contains(t, account.Groups, account.Users["user2"].AutoGroups[0], "groups must contain group3 from user groups") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user2", + Raw: jwt.MapClaims{"groups": []interface{}{"group1", "group3"}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + groups, err := manager.Store.GetAccountGroups(context.Background(), "accountID") + assert.NoError(t, err) + assert.Len(t, groups, 3, "new group3 should be added") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1, "new group should be added") }) t.Run("remove all JWT groups", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{}) - assert.True(t, updated, "account should be updated") - assert.Len(t, account.Users["user1"].AutoGroups, 1, "only non-JWT groups should remain") - assert.Contains(t, account.Users["user1"].AutoGroups, "group1", " group1 should still be present") + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{}}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") + assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") }) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index b399be82288..b6283a7e69a 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -27,7 +27,7 @@ type MockAccountManager struct { CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, 
userId, accountId, domain string) (string, error) + GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -194,14 +194,14 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface -func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { - if am.GetAccountIDByUserOrAccountIdFunc != nil { - return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) +// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface +func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) { + if am.GetAccountIDByUserIdFunc != nil { + return am.GetAccountIDByUserIdFunc(ctx, userId, domain) } return "", status.Errorf( codes.Unimplemented, - "method GetAccountIDByUserOrAccountID is not implemented", + "method GetAccountIDByUserID is not implemented", ) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index cce748a0f84..9e1ab27dcc3 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "runtime/debug" + "slices" "strings" "sync" "time" @@ -378,15 +379,26 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { Create(&usersToSave).Error } +// SaveUser saves the given user to the database. +func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) + } + return nil +} + // SaveGroups saves the given list of groups to the database. -// It updates existing groups if a conflict occurs. -func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { - groupsToSave := make([]nbgroup.Group, 0, len(groups)) - for _, group := range groups { - group.AccountID = accountID - groupsToSave = append(groupsToSave, *group) +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { + if len(groups) == 0 { + return nil + } + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) } - return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error + return nil } // DeleteHashedPAT2TokenIDIndex is noop in SqlStore @@ -1021,6 +1033,89 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } +// AddUserPeersToGroups adds the user's peers to specified groups in database. 
+func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { + if len(groupIDs) == 0 { + return nil + } + + var userPeerIDs []string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). + Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) + if result.Error != nil { + return status.Errorf(status.Internal, "issue finding user peers") + } + + groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) + for _, gid := range groupIDs { + group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) + if err != nil { + return err + } + + groupPeers := make(map[string]struct{}) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for _, pid := range userPeerIDs { + groupPeers[pid] = struct{}{} + } + + group.Peers = group.Peers[:0] + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } + + groupsToUpdate = append(groupsToUpdate, group) + } + + return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) +} + +// RemoveUserPeersFromGroups removes the user's peers from specified groups in database. +func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { + if len(groupIDs) == 0 { + return nil + } + + var userPeerIDs []string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). + Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) + if result.Error != nil { + return status.Errorf(status.Internal, "issue finding user peers") + } + + groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) + for _, gid := range groupIDs { + group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) + if err != nil { + return err + } + + if group.Name == "All" { + continue + } + + update := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + if !slices.Contains(userPeerIDs, pid) { + update = append(update, pid) + } + } + + group.Peers = update + groupsToUpdate = append(groupsToUpdate, group) + } + + return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) +} + +// GetUserPeers retrieves peers for a user. +func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { + return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) +} + func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account") @@ -1127,6 +1222,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren return &group, nil } +// SaveGroup saves a group to the store. +func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) + } + return nil +} + // GetAccountPolicies retrieves policies for an account. 
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index dc07849d9bf..4eed09c69b6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1185,3 +1185,33 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { require.NoError(t, err) assert.Equal(t, 2, setupKey.UsedTimes) } + +func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { + store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + group := &nbgroup.Group{ + ID: "group-id", + AccountID: "account-id", + Name: "group-name", + Issued: "api", + Peers: nil, + } + err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error { + err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group) + if err != nil { + t.Fatal("failed to save group") + return err + } + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + if err != nil { + t.Fatal("failed to get group") + return err + } + t.Logf("group: %v", group) + return nil + }) + assert.NoError(t, err) +} diff --git a/management/server/store.go b/management/server/store.go index 041c936ae56..50bc6afdfd2 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -60,6 +60,7 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) SaveUsers(accountID string, users map[string]*User) error + SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error @@ -68,7 +69,8 @@ type Store interface { GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) @@ -82,6 +84,7 @@ type Store interface { AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID 
string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error diff --git a/management/server/user.go b/management/server/user.go index 6d01561c6cc..38a8ac0c401 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -8,14 +8,14 @@ import ( "time" "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" ) const ( @@ -1254,6 +1254,74 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil } +// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. +func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd, + groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { + + if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { + return + } + + userPeerIDMap := make(map[string]struct{}, len(peers)) + for _, peer := range peers { + userPeerIDMap[peer.ID] = struct{}{} + } + + for _, gid := range groupsToAdd { + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") + } + addUserPeersToGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + for _, gid := range groupsToRemove { + group, ok := accountGroups[gid] + if !ok { + return nil, errors.New("group not found") + } + removeUserPeersFromGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + return groupsToUpdate, nil +} + +// addUserPeersToGroup adds the user's peers to the group. +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + groupPeers := make(map[string]struct{}, len(group.Peers)) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for pid := range userPeerIDs { + groupPeers[pid] = struct{}{} + } + + group.Peers = make([]string, 0, len(groupPeers)) + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } +} + +// removeUserPeersFromGroup removes user's peers from the group. 
+func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + // skip removing peers from group All + if group.Name == "All" { + return + } + + updatedPeers := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + if _, found := userPeerIDs[pid]; !found { + updatedPeers = append(updatedPeers, pid) + } + } + + group.Peers = updatedPeers +} + func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID { diff --git a/management/server/user_test.go b/management/server/user_test.go index ec0a1069576..1a5704551bc 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -813,10 +813,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") - assert.NoError(t, err) - - acc, err := am.Store.GetAccount(context.Background(), accID) + acc, err := am.Store.GetAccount(context.Background(), account.Id) assert.NoError(t, err) for _, id := range tc.expectedDeleted { From 8bf729c7b4e8d3f7a7a88c9dc2ac60b63c803a8c Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 4 Oct 2024 18:09:40 +0300 Subject: [PATCH 25/81] [management] Add AccountExists to AccountManager (#2694) * Add AccountExists method to account manager interface Signed-off-by: bcmmbaga * remove unused code Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 6 ++ management/server/mock_server/account_mock.go | 13 ++- management/server/sql_store.go | 79 ------------------- 3 files changed, 17 insertions(+), 81 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index da320385279..a9781b385a8 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -76,6 +76,7 @@ type AccountManager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) + AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error @@ -1261,6 +1262,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } +// AccountExists checks if an account exists. +func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { + return am.Store.AccountExists(ctx, LockingStrengthShare, accountID) +} + // GetAccountIDByUserID retrieves the account ID based on the userID provided. // If user does have an account, it returns the user's account ID. // If the user doesn't have an account, it creates one using the provided domain. 
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index b6283a7e69a..ec29222a460 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -27,6 +27,7 @@ type MockAccountManager struct { CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) + AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) @@ -58,7 +59,7 @@ type MockAccountManager struct { UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error @@ -194,6 +195,14 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } +// AccountExists mock implementation of AccountExists from server.AccountManager interface +func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { + if am.GetAccountIDByUserIdFunc != nil { + return am.AccountExistsFunc(ctx, accountID) + } + return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented") +} + // GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) { if am.GetAccountIDByUserIdFunc != nil { @@ -444,7 +453,7 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID // CreateRoute mock implementation of CreateRoute from server.AccountManager interface func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric 
int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 9e1ab27dcc3..d056015d823 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,7 +10,6 @@ import ( "path/filepath" "runtime" "runtime/debug" - "slices" "strings" "sync" "time" @@ -1033,84 +1032,6 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } -// AddUserPeersToGroups adds the user's peers to specified groups in database. -func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { - if len(groupIDs) == 0 { - return nil - } - - var userPeerIDs []string - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). - Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) - if result.Error != nil { - return status.Errorf(status.Internal, "issue finding user peers") - } - - groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) - for _, gid := range groupIDs { - group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) - if err != nil { - return err - } - - groupPeers := make(map[string]struct{}) - for _, pid := range group.Peers { - groupPeers[pid] = struct{}{} - } - - for _, pid := range userPeerIDs { - groupPeers[pid] = struct{}{} - } - - group.Peers = group.Peers[:0] - for pid := range groupPeers { - group.Peers = append(group.Peers, pid) - } - - groupsToUpdate = append(groupsToUpdate, group) - } - - return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) -} - -// RemoveUserPeersFromGroups removes the user's peers from specified groups in database. -func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error { - if len(groupIDs) == 0 { - return nil - } - - var userPeerIDs []string - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id"). - Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs) - if result.Error != nil { - return status.Errorf(status.Internal, "issue finding user peers") - } - - groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs)) - for _, gid := range groupIDs { - group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) - if err != nil { - return err - } - - if group.Name == "All" { - continue - } - - update := make([]string, 0, len(group.Peers)) - for _, pid := range group.Peers { - if !slices.Contains(userPeerIDs, pid) { - update = append(update, pid) - } - } - - group.Peers = update - groupsToUpdate = append(groupsToUpdate, group) - } - - return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) -} - // GetUserPeers retrieves peers for a user. 
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) From 5897a48e299d5553b6b375d2bc2b4df3e2dc24f1 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 4 Oct 2024 18:55:25 +0300 Subject: [PATCH 26/81] fix wrong reference (#2695) Signed-off-by: bcmmbaga --- management/server/mock_server/account_mock.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index ec29222a460..74557e2275c 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -197,7 +197,7 @@ func (am *MockAccountManager) CreateSetupKey( // AccountExists mock implementation of AccountExists from server.AccountManager interface func (am *MockAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { - if am.GetAccountIDByUserIdFunc != nil { + if am.AccountExistsFunc != nil { return am.AccountExistsFunc(ctx, accountID) } return false, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented") From f603cd92027f76c91b7d14ebb64ae6f3aa328639 Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Fri, 4 Oct 2024 11:15:16 -0600 Subject: [PATCH 27/81] [client] Check wginterface instead of engine ctx (#2676) Moving code to ensure wgInterface is gone right after context is cancelled/stop in the off chance that on next retry the backoff operation is permanently cancelled and interface is abandoned without destroying. --- client/internal/connect.go | 15 +++++++++------ client/internal/engine.go | 26 ++++++++++++++------------ 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index c77f95603d0..74dc1f1b56d 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -269,12 +269,6 @@ func (c *ConnectClient) run( checks := loginResp.GetChecks() c.engineMutex.Lock() - if c.engine != nil && c.engine.ctx.Err() != nil { - log.Info("Stopping Netbird Engine") - if err := c.engine.Stop(); err != nil { - log.Errorf("Failed to stop engine: %v", err) - } - } c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) c.engineMutex.Unlock() @@ -294,6 +288,15 @@ func (c *ConnectClient) run( } <-engineCtx.Done() + c.engineMutex.Lock() + if c.engine != nil && c.engine.wgInterface != nil { + log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name()) + if err := c.engine.Stop(); err != nil { + log.Errorf("Failed to stop engine: %v", err) + } + c.engine = nil + } + c.engineMutex.Unlock() c.statusRecorder.ClientTeardown() backOff.Reset() diff --git a/client/internal/engine.go b/client/internal/engine.go index c51901a225d..eac8ec098f6 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -251,6 +251,13 @@ func (e *Engine) Stop() error { } log.Info("Network monitor: stopped") + // stop/restore DNS first so dbus and friends don't complain because of a missing interface + e.stopDNSServer() + + if e.routeManager != nil { + e.routeManager.Stop() + } + err := e.removeAllPeers() if err != nil { return fmt.Errorf("failed to remove all peers: %s", err) @@ -1116,18 +1123,12 @@ func (e *Engine) close() { } } - // stop/restore DNS first so 
dbus and friends don't complain because of a missing interface - e.stopDNSServer() - - if e.routeManager != nil { - e.routeManager.Stop() - } - log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) } + e.wgInterface = nil } if !isNil(e.sshServer) { @@ -1395,7 +1396,7 @@ func (e *Engine) startNetworkMonitor() { } // Set a new timer to debounce rapid network changes - debounceTimer = time.AfterFunc(1*time.Second, func() { + debounceTimer = time.AfterFunc(2*time.Second, func() { // This function is called after the debounce period mu.Lock() defer mu.Unlock() @@ -1426,6 +1427,11 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { } func (e *Engine) stopDNSServer() { + if e.dnsServer == nil { + return + } + e.dnsServer.Stop() + e.dnsServer = nil err := fmt.Errorf("DNS server stopped") nsGroupStates := e.statusRecorder.GetDNSStates() for i := range nsGroupStates { @@ -1433,10 +1439,6 @@ func (e *Engine) stopDNSServer() { nsGroupStates[i].Error = err } e.statusRecorder.UpdateDNSStates(nsGroupStates) - if e.dnsServer != nil { - e.dnsServer.Stop() - e.dnsServer = nil - } } // isChecksEqual checks if two slices of checks are equal. From dbec24b52080bd14739e9b0dd8c950e4b0edf0cb Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Sun, 6 Oct 2024 17:01:13 +0200 Subject: [PATCH 28/81] [management] Remove admin check on getAccountByID (#2699) --- management/server/account.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index a9781b385a8..6ee0015f86f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -20,6 +20,11 @@ import ( cacheStore "github.com/eko/gocache/v3/store" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" + gocache "github.com/patrickmn/go-cache" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -36,10 +41,6 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" - gocache "github.com/patrickmn/go-cache" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" ) const ( @@ -1764,7 +1765,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s return nil, err } - if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + if user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") } From 2c1f5e46d5928a21458749251c4ee5eb96575239 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 7 Oct 2024 19:06:26 +0300 Subject: [PATCH 29/81] [management] Validate peer ownership during login (#2704) * check peer ownership in login Signed-off-by: bcmmbaga * update error message Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/peer.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/management/server/peer.go b/management/server/peer.go index da958673414..a7d4f3b06aa 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -693,6 +693,11 @@ func (am 
*DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) updateRemotePeers := false if login.UserID != "" { + if peer.UserID != login.UserID { + log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID) + return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user") + } + changed, err := am.handleUserPeer(ctx, peer, settings) if err != nil { return nil, nil, nil, err From 44e81073832c2d1167ef6ec284c8bb9f4c5bc3d8 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 8 Oct 2024 11:21:11 +0200 Subject: [PATCH 30/81] [client] Limit P2P attempts and restart on specific events (#2657) --- client/internal/peer/conn.go | 60 ++++--- client/internal/peer/conn_monitor.go | 212 +++++++++++++++++++++++++ client/internal/peer/stdnet.go | 4 +- client/internal/peer/stdnet_android.go | 4 +- client/internal/peer/worker_ice.go | 63 ++++---- 5 files changed, 285 insertions(+), 58 deletions(-) create mode 100644 client/internal/peer/conn_monitor.go diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ad84bd7006b..0d4ad2396b3 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -32,6 +32,8 @@ const ( connPriorityRelay ConnPriority = 1 connPriorityICETurn ConnPriority = 1 connPriorityICEP2P ConnPriority = 2 + + reconnectMaxElapsedTime = 30 * time.Minute ) type WgConfig struct { @@ -83,6 +85,7 @@ type Conn struct { wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager allowedIPsIP string handshaker *Handshaker @@ -108,6 +111,8 @@ type Conn struct { // for reconnection operations iCEDisconnected chan bool relayDisconnected chan bool + connMonitor *ConnMonitor + reconnectCh <-chan struct{} } // NewConn creates a new not opened Conn to the remote peer. 
@@ -123,21 +128,31 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu connLog := log.WithField("peer", config.Key) var conn = &Conn{ - log: connLog, - ctx: ctx, - ctxCancel: ctxCancel, - config: config, - statusRecorder: statusRecorder, - wgProxyFactory: wgProxyFactory, - signaler: signaler, - relayManager: relayManager, - allowedIPsIP: allowedIPsIP.String(), - statusRelay: NewAtomicConnStatus(), - statusICE: NewAtomicConnStatus(), + log: connLog, + ctx: ctx, + ctxCancel: ctxCancel, + config: config, + statusRecorder: statusRecorder, + wgProxyFactory: wgProxyFactory, + signaler: signaler, + iFaceDiscover: iFaceDiscover, + relayManager: relayManager, + allowedIPsIP: allowedIPsIP.String(), + statusRelay: NewAtomicConnStatus(), + statusICE: NewAtomicConnStatus(), + iCEDisconnected: make(chan bool, 1), relayDisconnected: make(chan bool, 1), } + conn.connMonitor, conn.reconnectCh = NewConnMonitor( + signaler, + iFaceDiscover, + config, + conn.relayDisconnected, + conn.iCEDisconnected, + ) + rFns := WorkerRelayCallbacks{ OnConnReady: conn.relayConnectionIsReady, OnDisconnected: conn.onWorkerRelayStateDisconnected, @@ -200,6 +215,8 @@ func (conn *Conn) startHandshakeAndReconnect() { conn.log.Errorf("failed to send initial offer: %v", err) } + go conn.connMonitor.Start(conn.ctx) + if conn.workerRelay.IsController() { conn.reconnectLoopWithRetry() } else { @@ -309,12 +326,14 @@ func (conn *Conn) reconnectLoopWithRetry() { // With it, we can decrease to send necessary offer select { case <-conn.ctx.Done(): + return case <-time.After(3 * time.Second): } ticker := conn.prepareExponentTicker() defer ticker.Stop() time.Sleep(1 * time.Second) + for { select { case t := <-ticker.C: @@ -342,20 +361,11 @@ func (conn *Conn) reconnectLoopWithRetry() { if err != nil { conn.log.Errorf("failed to do handshake: %v", err) } - case changed := <-conn.relayDisconnected: - if !changed { - continue - } - conn.log.Debugf("Relay state changed, reset reconnect timer") - ticker.Stop() - ticker = conn.prepareExponentTicker() - case changed := <-conn.iCEDisconnected: - if !changed { - continue - } - conn.log.Debugf("ICE state changed, reset reconnect timer") + + case <-conn.reconnectCh: ticker.Stop() ticker = conn.prepareExponentTicker() + case <-conn.ctx.Done(): conn.log.Debugf("context is done, stop reconnect loop") return @@ -366,10 +376,10 @@ func (conn *Conn) reconnectLoopWithRetry() { func (conn *Conn) prepareExponentTicker() *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 0.01, + RandomizationFactor: 0.1, Multiplier: 2, MaxInterval: conn.config.Timeout, - MaxElapsedTime: 0, + MaxElapsedTime: reconnectMaxElapsedTime, Stop: backoff.Stop, Clock: backoff.SystemClock, }, conn.ctx) diff --git a/client/internal/peer/conn_monitor.go b/client/internal/peer/conn_monitor.go new file mode 100644 index 00000000000..75722c99011 --- /dev/null +++ b/client/internal/peer/conn_monitor.go @@ -0,0 +1,212 @@ +package peer + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/pion/ice/v3" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +const ( + signalerMonitorPeriod = 5 * time.Second + candidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second +) + +type ConnMonitor struct { + signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover + config ConnConfig + relayDisconnected chan bool + iCEDisconnected chan bool + 
reconnectCh chan struct{} + currentCandidates []ice.Candidate + candidatesMu sync.Mutex +} + +func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) { + reconnectCh := make(chan struct{}, 1) + cm := &ConnMonitor{ + signaler: signaler, + iFaceDiscover: iFaceDiscover, + config: config, + relayDisconnected: relayDisconnected, + iCEDisconnected: iCEDisconnected, + reconnectCh: reconnectCh, + } + return cm, reconnectCh +} + +func (cm *ConnMonitor) Start(ctx context.Context) { + signalerReady := make(chan struct{}, 1) + go cm.monitorSignalerReady(ctx, signalerReady) + + localCandidatesChanged := make(chan struct{}, 1) + go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged) + + for { + select { + case changed := <-cm.relayDisconnected: + if !changed { + continue + } + log.Debugf("Relay state changed, triggering reconnect") + cm.triggerReconnect() + + case changed := <-cm.iCEDisconnected: + if !changed { + continue + } + log.Debugf("ICE state changed, triggering reconnect") + cm.triggerReconnect() + + case <-signalerReady: + log.Debugf("Signaler became ready, triggering reconnect") + cm.triggerReconnect() + + case <-localCandidatesChanged: + log.Debugf("Local candidates changed, triggering reconnect") + cm.triggerReconnect() + + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) { + if cm.signaler == nil { + return + } + + ticker := time.NewTicker(signalerMonitorPeriod) + defer ticker.Stop() + + lastReady := true + for { + select { + case <-ticker.C: + currentReady := cm.signaler.Ready() + if !lastReady && currentReady { + select { + case signalerReady <- struct{}{}: + default: + } + } + lastReady = currentReady + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) { + ufrag, pwd, err := generateICECredentials() + if err != nil { + log.Warnf("Failed to generate ICE credentials: %v", err) + return + } + + ticker := time.NewTicker(candidatesMonitorPeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil { + log.Warnf("Failed to handle candidate tick: %v", err) + } + case <-ctx.Done(): + return + } + } +} + +func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error { + log.Debugf("Gathering ICE candidates") + + transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList) + if err != nil { + log.Errorf("failed to create pion's stdnet: %s", err) + } + + agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd) + if err != nil { + return fmt.Errorf("create ICE agent: %w", err) + } + defer func() { + if err := agent.Close(); err != nil { + log.Warnf("Failed to close ICE agent: %v", err) + } + }() + + gatherDone := make(chan struct{}) + err = agent.OnCandidate(func(c ice.Candidate) { + log.Tracef("Got candidate: %v", c) + if c == nil { + close(gatherDone) + } + }) + if err != nil { + return fmt.Errorf("set ICE candidate handler: %w", err) + } + + if err := agent.GatherCandidates(); err != nil { + return fmt.Errorf("gather ICE candidates: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) + defer cancel() + + select { + case 
<-ctx.Done(): + return fmt.Errorf("wait for gathering: %w", ctx.Err()) + case <-gatherDone: + } + + candidates, err := agent.GetLocalCandidates() + if err != nil { + return fmt.Errorf("get local candidates: %w", err) + } + log.Tracef("Got candidates: %v", candidates) + + if changed := cm.updateCandidates(candidates); changed { + select { + case localCandidatesChanged <- struct{}{}: + default: + } + } + + return nil +} + +func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool { + cm.candidatesMu.Lock() + defer cm.candidatesMu.Unlock() + + if len(cm.currentCandidates) != len(newCandidates) { + cm.currentCandidates = newCandidates + return true + } + + for i, candidate := range cm.currentCandidates { + if candidate.Address() != newCandidates[i].Address() { + cm.currentCandidates = newCandidates + return true + } + } + + return false +} + +func (cm *ConnMonitor) triggerReconnect() { + select { + case cm.reconnectCh <- struct{}{}: + default: + } +} diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go index ae31ebbf067..96d211dbc77 100644 --- a/client/internal/peer/stdnet.go +++ b/client/internal/peer/stdnet.go @@ -6,6 +6,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" ) -func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList) +func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNet(ifaceBlacklist) } diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go index b411405bb95..a39a03b1c83 100644 --- a/client/internal/peer/stdnet_android.go +++ b/client/internal/peer/stdnet_android.go @@ -2,6 +2,6 @@ package peer import "github.com/netbirdio/netbird/client/internal/stdnet" -func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) +func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index c4e9d195074..c86c1858fdc 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -233,41 +233,16 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { - transportNet, err := w.newStdNet() + transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) if err != nil { w.log.Errorf("failed to create pion's stdnet: %s", err) } - iceKeepAlive := iceKeepAlive() - iceDisconnectedTimeout := iceDisconnectedTimeout() - iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - - agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI), - CandidateTypes: relaySupport, - InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList), - UDPMux: w.config.ICEConfig.UDPMux, - UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx, - NAT1To1IPs: w.config.ICEConfig.NATExternalIPs, - Net: transportNet, - FailedTimeout: &failedTimeout, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, - RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, - LocalUfrag: w.localUfrag, - 
LocalPwd: w.localPwd, - } - - if w.config.ICEConfig.DisableIPv6Discovery { - agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} - } - w.sentExtraSrflx = false - agent, err := ice.NewAgent(agentConfig) + + agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd) if err != nil { - return nil, err + return nil, fmt.Errorf("create agent: %w", err) } err = agent.OnCandidate(w.onICECandidate) @@ -390,6 +365,36 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA } } +func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { + iceKeepAlive := iceKeepAlive() + iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() + + agentConfig := &ice.AgentConfig{ + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI), + CandidateTypes: candidateTypes, + InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList), + UDPMux: config.ICEConfig.UDPMux, + UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx, + NAT1To1IPs: config.ICEConfig.NATExternalIPs, + Net: transportNet, + FailedTimeout: &failedTimeout, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, + LocalUfrag: ufrag, + LocalPwd: pwd, + } + + if config.ICEConfig.DisableIPv6Discovery { + agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} + } + + return ice.NewAgent(agentConfig) +} + func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ From d4ef84fe6e02e932fdfa24be2bcf2416634f832b Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:33:58 +0200 Subject: [PATCH 31/81] [management] Propagate error in store errors (#2709) --- management/server/sql_store.go | 46 +++++++++++++++---------------- management/server/status/error.go | 8 ++++-- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index d056015d823..67df29ef07d 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -431,7 +431,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -444,7 +444,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.NewSetupKeyNotFoundError() + return nil, status.NewSetupKeyNotFoundError(result.Error) } if key.AccountID == "" { @@ -462,7 +462,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } 
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return token.ID, nil @@ -476,7 +476,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if token.UserID == "" { @@ -560,7 +560,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us @@ -623,7 +623,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if user.AccountID == "" { @@ -640,7 +640,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if peer.AccountID == "" { @@ -658,7 +658,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - return nil, status.Errorf(status.Internal, "issue getting account from store") + return nil, status.NewGetAccountFromStoreError(result.Error) } if peer.AccountID == "" { @@ -676,7 +676,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -689,7 +689,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.Errorf(status.Internal, "issue getting account from store") + return "", status.NewGetAccountFromStoreError(result.Error) } return accountID, nil @@ -702,7 +702,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - return "", status.NewSetupKeyNotFoundError() + return "", status.NewSetupKeyNotFoundError(result.Error) } if 
accountID == "" { @@ -723,7 +723,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "no peers found for the account") } - return nil, status.Errorf(status.Internal, "issue getting IPs from store") + return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) } // Convert the JSON strings to net.IP objects @@ -751,7 +751,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock return nil, status.Errorf(status.NotFound, "no peers found for the account") } log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting dns labels from store") + return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error) } return labels, nil @@ -764,7 +764,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "issue getting network from store") + return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) } return accountNetwork.Network, nil } @@ -776,7 +776,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } - return nil, status.Errorf(status.Internal, "issue getting peer from store") + return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) } return &peer, nil @@ -788,7 +788,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } - return nil, status.Errorf(status.Internal, "issue getting settings from store") + return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err) } return accountSettings.Settings, nil } @@ -956,7 +956,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "setup key not found") } - return nil, status.NewSetupKeyNotFoundError() + return nil, status.NewSetupKeyNotFoundError(result.Error) } return &setupKey, nil } @@ -988,7 +988,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") } - return status.Errorf(status.Internal, "issue finding group 'All'") + return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) } for _, existingPeerID := range group.Peers { @@ -1000,7 +1000,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) if err := s.db.Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group 'All'") + return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } return nil @@ -1014,7 +1014,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group not found 
for account") } - return status.Errorf(status.Internal, "issue finding group") + return status.Errorf(status.Internal, "issue finding group: %s", result.Error) } for _, existingPeerID := range group.Peers { @@ -1026,7 +1026,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) if err := s.db.Save(&group).Error; err != nil { - return status.Errorf(status.Internal, "issue updating group") + return status.Errorf(status.Internal, "issue updating group: %s", err) } return nil @@ -1039,7 +1039,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { - return status.Errorf(status.Internal, "issue adding peer to account") + return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } return nil @@ -1048,7 +1048,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { - return status.Errorf(status.Internal, "issue incrementing network serial count") + return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) } return nil } diff --git a/management/server/status/error.go b/management/server/status/error.go index d7fde35b998..29d185216d8 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -102,8 +102,12 @@ func NewPeerLoginExpiredError() error { } // NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key -func NewSetupKeyNotFoundError() error { - return Errorf(NotFound, "setup key not found") +func NewSetupKeyNotFoundError(err error) error { + return Errorf(NotFound, "setup key not found: %s", err) +} + +func NewGetAccountFromStoreError(err error) error { + return Errorf(Internal, "issue getting account from store: %s", err) } // NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store From b1eda43f4b748ee9940e9a2399f200b60e55f076 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Wed, 9 Oct 2024 13:56:25 +0100 Subject: [PATCH 32/81] Add Link to the Lawrence Systems video (#2711) --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index aa3ec41e533..270c9ad8707 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,8 @@ ![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) +### NetBird on Lawrence Systems (Video) +[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) ### Key features @@ -62,6 +64,7 @@ | | |
- [x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)
- [x] OpenWRT
- [x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)
- [x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)
- [x] Docker
    | + ### Quickstart with NetBird Cloud - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) From b79c1d64cc10871e70c94ee5f47fdf2b4773f5a7 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 9 Oct 2024 20:17:25 +0200 Subject: [PATCH 33/81] [management] Make max open db conns configurable (#2713) --- .github/workflows/golang-test-linux.yml | 2 +- .github/workflows/release.yml | 2 +- management/server/sql_store.go | 11 +++++++++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 524f35f6f47..d6adcb27aee 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -16,7 +16,7 @@ jobs: matrix: arch: [ '386','amd64' ] store: [ 'sqlite', 'postgres'] - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Install Go uses: actions/setup-go@v5 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7af6d3e4d94..b2e2437e6bb 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,7 +20,7 @@ concurrency: jobs: release: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 env: flags: "" steps: diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 67df29ef07d..fe4dcafdb26 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -10,6 +10,7 @@ import ( "path/filepath" "runtime" "runtime/debug" + "strconv" "strings" "sync" "time" @@ -63,8 +64,14 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr if err != nil { return nil, err } - conns := runtime.NumCPU() - sql.SetMaxOpenConns(conns) // TODO: make it configurable + + conns, err := strconv.Atoi(os.Getenv("NB_SQL_MAX_OPEN_CONNS")) + if err != nil { + conns = runtime.NumCPU() + } + sql.SetMaxOpenConns(conns) + + log.Infof("Set max open db connections to %d", conns) if err := migrate(ctx, db); err != nil { return nil, fmt.Errorf("migrate: %w", err) From 6ce09bca1680c50012cd9467317290808b620808 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 9 Oct 2024 20:46:23 +0200 Subject: [PATCH 34/81] Add support to envsub go management configurations (#2708) This change allows users to reference environment variables using Go template format, like {{ .EnvName }} Moved the previous file test code to file_suite_test.go. 
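To make the new substitution behaviour concrete, here is a minimal usage sketch (not part of the patch). The file name demo.json, the MY_TOKEN variable, and the demoConfig struct are invented for illustration; only util.ReadJsonWithEnvSub comes from the change itself:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/netbirdio/netbird/util"
)

// demoConfig is a hypothetical struct used only for this illustration.
type demoConfig struct {
	Token string
}

func main() {
	// Hypothetical config file: the Token field references an environment
	// variable using Go template syntax, as introduced by this patch.
	cfgJSON := []byte(`{"Token": "{{ .MY_TOKEN }}"}`)
	if err := os.WriteFile("demo.json", cfgJSON, 0o600); err != nil {
		log.Fatal(err)
	}
	if err := os.Setenv("MY_TOKEN", "secret-from-env"); err != nil {
		log.Fatal(err)
	}

	cfg := &demoConfig{}
	// ReadJsonWithEnvSub replaces {{ .MY_TOKEN }} with the value of the
	// MY_TOKEN environment variable before unmarshalling the JSON.
	if _, err := util.ReadJsonWithEnvSub("demo.json", cfg); err != nil {
		log.Fatalf("read config: %v", err)
	}

	fmt.Println(cfg.Token) // secret-from-env
}
```

The patch wires this into loadMgmtConfig in management/cmd/management.go by swapping util.ReadJson for util.ReadJsonWithEnvSub, so the same {{ .ENV_NAME }} syntax becomes available in the management service configuration file.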
--- management/cmd/management.go | 2 +- util/file.go | 53 +++++++ util/file_suite_test.go | 126 +++++++++++++++ util/file_test.go | 292 ++++++++++++++++++++++------------- 4 files changed, 362 insertions(+), 111 deletions(-) create mode 100644 util/file_suite_test.go diff --git a/management/cmd/management.go b/management/cmd/management.go index 78b1a8d631f..719d1a78c1a 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -475,7 +475,7 @@ func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handle func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*server.Config, error) { loadedConfig := &server.Config{} - _, err := util.ReadJson(mgmtConfigPath, loadedConfig) + _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig) if err != nil { return nil, err } diff --git a/util/file.go b/util/file.go index 8355488c98a..ecaecd22260 100644 --- a/util/file.go +++ b/util/file.go @@ -1,11 +1,15 @@ package util import ( + "bytes" "context" "encoding/json" + "fmt" "io" "os" "path/filepath" + "strings" + "text/template" log "github.com/sirupsen/logrus" ) @@ -160,6 +164,55 @@ func ReadJson(file string, res interface{}) (interface{}, error) { return res, nil } +// ReadJsonWithEnvSub reads JSON config file and maps to a provided interface with environment variable substitution +func ReadJsonWithEnvSub(file string, res interface{}) (interface{}, error) { + envVars := getEnvMap() + + f, err := os.Open(file) + if err != nil { + return nil, err + } + defer f.Close() + + bs, err := io.ReadAll(f) + if err != nil { + return nil, err + } + + t, err := template.New("").Parse(string(bs)) + if err != nil { + return nil, fmt.Errorf("error parsing template: %v", err) + } + + var output bytes.Buffer + // Execute the template, substituting environment variables + err = t.Execute(&output, envVars) + if err != nil { + return nil, fmt.Errorf("error executing template: %v", err) + } + + err = json.Unmarshal(output.Bytes(), &res) + if err != nil { + return nil, fmt.Errorf("failed parsing Json file after template was executed, err: %v", err) + } + + return res, nil +} + +// getEnvMap Convert the output of os.Environ() to a map +func getEnvMap() map[string]string { + envMap := make(map[string]string) + + for _, env := range os.Environ() { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = parts[1] + } + } + + return envMap +} + // CopyFileContents copies contents of the given src file to the dst file func CopyFileContents(src, dst string) (err error) { in, err := os.Open(src) diff --git a/util/file_suite_test.go b/util/file_suite_test.go new file mode 100644 index 00000000000..3de7db49bdd --- /dev/null +++ b/util/file_suite_test.go @@ -0,0 +1,126 @@ +package util_test + +import ( + "crypto/md5" + "encoding/hex" + "io" + "os" + + . "github.com/onsi/ginkgo" + . 
"github.com/onsi/gomega" + + "github.com/netbirdio/netbird/util" +) + +var _ = Describe("Client", func() { + + var ( + tmpDir string + ) + + type TestConfig struct { + SomeMap map[string]string + SomeArray []string + SomeField int + } + + BeforeEach(func() { + var err error + tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + err := os.RemoveAll(tmpDir) + Expect(err).NotTo(HaveOccurred()) + }) + + Describe("Config", func() { + Context("in JSON format", func() { + It("should be written and read successfully", func() { + + m := make(map[string]string) + m["key1"] = "value1" + m["key2"] = "value2" + + arr := []string{"value1", "value2"} + + written := &TestConfig{ + SomeMap: m, + SomeArray: arr, + SomeField: 99, + } + + err := util.WriteJson(tmpDir+"/testconfig.json", written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) + Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) + Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) + Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) + + }) + }) + }) + + Describe("Copying file contents", func() { + Context("from one file to another", func() { + It("should be successful", func() { + + src := tmpDir + "/copytest_src" + dst := tmpDir + "/copytest_dst" + + err := util.WriteJson(src, []string{"1", "2", "3"}) + Expect(err).NotTo(HaveOccurred()) + + err = util.CopyFileContents(src, dst) + Expect(err).NotTo(HaveOccurred()) + + hashSrc := md5.New() + hashDst := md5.New() + + srcFile, err := os.Open(src) + Expect(err).NotTo(HaveOccurred()) + + dstFile, err := os.Open(dst) + Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashSrc, srcFile) + Expect(err).NotTo(HaveOccurred()) + + _, err = io.Copy(hashDst, dstFile) + Expect(err).NotTo(HaveOccurred()) + + err = srcFile.Close() + Expect(err).NotTo(HaveOccurred()) + + err = dstFile.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) + }) + }) + }) + + Describe("Handle config file without full path", func() { + Context("config file handling", func() { + It("should be successful", func() { + written := &TestConfig{ + SomeField: 123, + } + cfgFile := "test_cfg.json" + defer os.Remove(cfgFile) + + err := util.WriteJson(cfgFile, written) + Expect(err).NotTo(HaveOccurred()) + + read, err := util.ReadJson(cfgFile, &TestConfig{}) + Expect(err).NotTo(HaveOccurred()) + Expect(read).NotTo(BeNil()) + }) + }) + }) +}) diff --git a/util/file_test.go b/util/file_test.go index 3de7db49bdd..1330e738e8d 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,126 +1,198 @@ -package util_test +package util import ( - "crypto/md5" - "encoding/hex" - "io" "os" - - . "github.com/onsi/ginkgo" - . 
"github.com/onsi/gomega" - - "github.com/netbirdio/netbird/util" + "reflect" + "strings" + "testing" ) -var _ = Describe("Client", func() { - - var ( - tmpDir string - ) - - type TestConfig struct { - SomeMap map[string]string - SomeArray []string - SomeField int +func TestReadJsonWithEnvSub(t *testing.T) { + type Config struct { + CertFile string `json:"CertFile"` + Credentials string `json:"Credentials"` + NestedOption struct { + URL string `json:"URL"` + } `json:"NestedOption"` } - BeforeEach(func() { - var err error - tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") - Expect(err).NotTo(HaveOccurred()) - }) - - AfterEach(func() { - err := os.RemoveAll(tmpDir) - Expect(err).NotTo(HaveOccurred()) - }) - - Describe("Config", func() { - Context("in JSON format", func() { - It("should be written and read successfully", func() { - - m := make(map[string]string) - m["key1"] = "value1" - m["key2"] = "value2" + type testCase struct { + name string + envVars map[string]string + jsonTemplate string + expectedResult Config + expectError bool + errorContains string + } - arr := []string{"value1", "value2"} + tests := []testCase{ + { + name: "All environment variables set", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://env.testing.com", + }, + }, + expectError: false, + }, + { + name: "Missing environment variable", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + // "URL" is intentionally missing + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/env_cert.crt", + Credentials: "env_credentials", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "", + }, + }, + expectError: false, + }, + { + name: "Invalid JSON template", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, // Note the missing closing brace in "{{ .CREDENTIALS }" + expectedResult: Config{}, + expectError: true, + errorContains: "unexpected \"}\" in operand", + }, + { + name: "No substitutions", + envVars: map[string]string{ + "CERT_FILE": "/etc/certs/env_cert.crt", + "CREDENTIALS": "env_credentials", + "URL": "https://env.testing.com", + }, + jsonTemplate: `{ + "CertFile": "/etc/certs/cert.crt", + "Credentials": "admnlknflkdasdf", + "NestedOption" : { + "URL": "https://testing.com" + } + }`, + expectedResult: Config{ + CertFile: "/etc/certs/cert.crt", + Credentials: "admnlknflkdasdf", + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: "https://testing.com", + }, + }, + expectError: false, + }, + { + name: "Should fail when Invalid characters in variables", + envVars: map[string]string{ + "CERT_FILE": `"/etc/certs/"cert".crt"`, + "CREDENTIALS": `env_credentia{ls}`, + "URL": `https://env.testing.com?param={{value}}`, + }, + jsonTemplate: `{ + 
"CertFile": "{{ .CERT_FILE }}", + "Credentials": "{{ .CREDENTIALS }}", + "NestedOption": { + "URL": "{{ .URL }}" + } + }`, + expectedResult: Config{ + CertFile: `"/etc/certs/"cert".crt"`, + Credentials: `env_credentia{ls}`, + NestedOption: struct { + URL string `json:"URL"` + }{ + URL: `https://env.testing.com?param={{value}}`, + }, + }, + expectError: true, + }, + } - written := &TestConfig{ - SomeMap: m, - SomeArray: arr, - SomeField: 99, + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + for key, value := range tc.envVars { + t.Setenv(key, value) + } + + tempFile, err := os.CreateTemp("", "config*.json") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + defer func() { + err = os.Remove(tempFile.Name()) + if err != nil { + t.Logf("Failed to remove temp file: %v", err) } + }() - err := util.WriteJson(tmpDir+"/testconfig.json", written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) - Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) - Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) - Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) - - }) - }) - }) - - Describe("Copying file contents", func() { - Context("from one file to another", func() { - It("should be successful", func() { - - src := tmpDir + "/copytest_src" - dst := tmpDir + "/copytest_dst" - - err := util.WriteJson(src, []string{"1", "2", "3"}) - Expect(err).NotTo(HaveOccurred()) + _, err = tempFile.WriteString(tc.jsonTemplate) + if err != nil { + t.Fatalf("Failed to write to temp file: %v", err) + } + err = tempFile.Close() + if err != nil { + t.Fatalf("Failed to close temp file: %v", err) + } - err = util.CopyFileContents(src, dst) - Expect(err).NotTo(HaveOccurred()) + var result Config - hashSrc := md5.New() - hashDst := md5.New() + _, err = ReadJsonWithEnvSub(tempFile.Name(), &result) - srcFile, err := os.Open(src) - Expect(err).NotTo(HaveOccurred()) - - dstFile, err := os.Open(dst) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashSrc, srcFile) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashDst, dstFile) - Expect(err).NotTo(HaveOccurred()) - - err = srcFile.Close() - Expect(err).NotTo(HaveOccurred()) - - err = dstFile.Close() - Expect(err).NotTo(HaveOccurred()) - - Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) - }) - }) - }) - - Describe("Handle config file without full path", func() { - Context("config file handling", func() { - It("should be successful", func() { - written := &TestConfig{ - SomeField: 123, + if tc.expectError { + if err == nil { + t.Fatalf("Expected error but got none") } - cfgFile := "test_cfg.json" - defer os.Remove(cfgFile) - - err := util.WriteJson(cfgFile, written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(cfgFile, &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - }) + if !strings.Contains(err.Error(), tc.errorContains) { + t.Errorf("Expected error containing '%s', but got '%v'", tc.errorContains, err) + } + } else { + if err != nil { + t.Fatalf("ReadJsonWithEnvSub failed: %v", err) + } + if !reflect.DeepEqual(result, tc.expectedResult) { + t.Errorf("Result does not match expected.\nGot: %+v\nExpected: %+v", 
result, tc.expectedResult) + } + } }) - }) -}) + } +} From 8284ae959cd38f0b8c8cb7b7b711699a21c68417 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 10 Oct 2024 12:35:03 +0200 Subject: [PATCH 35/81] [management] Move testdata to sql files (#2693) --- .github/workflows/golang-test-darwin.yml | 2 +- .github/workflows/golang-test-linux.yml | 2 +- client/cmd/testutil_test.go | 4 +- client/internal/engine_test.go | 4 +- client/server/server_test.go | 2 +- client/testdata/store.sql | 36 +++ client/testdata/store.sqlite | Bin 163840 -> 0 bytes management/client/client_test.go | 22 +- management/server/account_test.go | 2 +- management/server/dns_test.go | 2 +- management/server/management_proto_test.go | 11 +- management/server/management_test.go | 4 +- management/server/nameserver_test.go | 2 +- management/server/peer_test.go | 8 +- management/server/route_test.go | 2 +- management/server/sql_store.go | 22 -- management/server/sql_store_test.go | 302 ++++++++---------- management/server/store.go | 63 ++-- management/server/testdata/extended-store.sql | 37 +++ .../server/testdata/extended-store.sqlite | Bin 163840 -> 0 bytes management/server/testdata/store.sql | 33 ++ management/server/testdata/store.sqlite | Bin 163840 -> 0 bytes .../server/testdata/store_policy_migrate.sql | 35 ++ .../testdata/store_policy_migrate.sqlite | Bin 163840 -> 0 bytes .../testdata/store_with_expired_peers.sql | 35 ++ .../testdata/store_with_expired_peers.sqlite | Bin 163840 -> 0 bytes management/server/testdata/storev1.sql | 39 +++ management/server/testdata/storev1.sqlite | Bin 163840 -> 0 bytes 28 files changed, 419 insertions(+), 250 deletions(-) create mode 100644 client/testdata/store.sql delete mode 100644 client/testdata/store.sqlite create mode 100644 management/server/testdata/extended-store.sql delete mode 100644 management/server/testdata/extended-store.sqlite create mode 100644 management/server/testdata/store.sql delete mode 100644 management/server/testdata/store.sqlite create mode 100644 management/server/testdata/store_policy_migrate.sql delete mode 100644 management/server/testdata/store_policy_migrate.sqlite create mode 100644 management/server/testdata/store_with_expired_peers.sql delete mode 100644 management/server/testdata/store_with_expired_peers.sqlite create mode 100644 management/server/testdata/storev1.sql delete mode 100644 management/server/testdata/storev1.sqlite diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 2aaef756437..88db8c5e89f 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -42,4 +42,4 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... + run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index d6adcb27aee..e1e1ff2362e 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -49,7 +49,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... 
+ run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... test_client_on_docker: runs-on: ubuntu-20.04 diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 033d1bb6ab8..d998f9ea9e6 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -38,7 +38,7 @@ func startTestingServices(t *testing.T) string { signalAddr := signalLis.Addr().String() config.Signal.URI = signalAddr - _, mgmLis := startManagement(t, config, "../testdata/store.sqlite") + _, mgmLis := startManagement(t, config, "../testdata/store.sql") mgmAddr := mgmLis.Addr().String() return mgmAddr } @@ -71,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc. t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromSqlite(context.Background(), testFile, t.TempDir()) + store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 3d1983c6bda..74b10ee44fa 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -832,7 +832,7 @@ func TestEngine_MultiplePeers(t *testing.T) { return } defer sigServer.Stop() - mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sqlite") + mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql") if err != nil { t.Fatal(err) return @@ -1080,7 +1080,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), testFile, config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index e534ad7e2d6..61bdaf660d2 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -110,7 +110,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSqlite(context.Background(), "", config.Datadir) + store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir) if err != nil { return nil, "", err } diff --git a/client/testdata/store.sql b/client/testdata/store.sql new file mode 100644 index 00000000000..ed539548613 --- /dev/null +++ b/client/testdata/store.sql @@ -0,0 +1,36 @@ +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY 
(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` 
text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 21:28:24.830195+02:00','','',0,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); +INSERT INTO installations VALUES(1,''); + +COMMIT; diff --git a/client/testdata/store.sqlite b/client/testdata/store.sqlite deleted file mode 100644 index 118c2bebc9f1fd29751627c36304d301ba156781..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163840 zcmeI5Piz}ke#beIMbfe*#>plVXQN$8iDR>oY?6`{$qU1kWjbrbmJ`Xwu3-ej5&1|> zM9wfXL)k*G1!yP1Vu512XpTMf)I$&3Vu3d2JrzCdvA|x7_S#zy1+s_!=FPm}d*slP z>)j$DzO>BYy!U(mKEL1l{ob45pPi3xTRzJ-9Jg)y`Q_A0DNRfLc|M;?rQW3f7wEru zyg@Hch!Z+$((6$#-%MR}|8^=6&V3Weyqf!F`ZuQ^PG3n+P3NY*p4yq*o4PQ0LNh16 zPW?9Z75$^fyDz8pYiG5TW!Kqb@6#4_&@j8c!_SST>vTJwQ3{W*^yXg5=pU_Xy{kOi zsy{Y5%=H}GY#C;)#yrpPoqc9|M%QDmVbzm!&ung2HttpOx3+FnewwcyT}?HAcPn2_ zuB>{Y8Z}leJlUDe=ZVVk!Lx6**w55nFR7B}y1?cCemxV3dJUv2DjU3f;vGC@2iBa5}AD)khf|*P2n{hnd~!L38Rb)tvbGsbO;(loLLN zN)MRR?UQN!yslZX#fC-q8Nngx=}~o)fLdvgQ22Q!!;SNo?Z`(_6}+CTYMeZ+dW3QE zCPHuIF~8d}qy(!Y4699MY3w3FY&wSHDPsG~wOLC^syVJOf+{c7X_KhyFL75&CJP3G_n#T`C}o8vQdIHDfk&1IkE=EB-3jM6MRd63TN*XFg} z+1Llg(_Ido{lv_$Q|f0*!8UYI{j34k_w=0;8GU|U`%>r5sxQ}nHKn?(e#Ue7E_tFz zA!kIS3(cIyt*ni@onT2+A4qlC0WC;nvn5lfg&8fo$YaDYFlk{`O$gCl&ulWIWwkA8 
zJ1(7Iy3b)oePSSH$f008+;rB%ndHBJkF#brE5<~$~fN$AN z>W1ZUpK4H)A+ri%$-s&*7UjPF7!LQ@1*DF|f-Y-zTc#Ur5)Lf($QWLS#?sM#km5(j zY>x8!vGgYt?V}RIu^~yS)-tWO;;a=Zv(<7QjY34rUe~1UC%)aB3#DT5;_Ax9;>F^^;=&!%+o!Zkm)15mHm+V> z_}Rmcx4!z--BfCtI;Q1*pQ1lJKmY_l00ck)1V8`;KmY_l00ck)1ioVg&QIPyd7KLl z;r{;*soWpFV*^BYK>!3m00ck)1V8`;KmY_l00ck)1dc^ua&qQm`2PP~DwjJ}5eR_* z2!H?xfB*=900@8p2!H?xfWSBq=-FCs=KlS+=5}^=ve|2Avf1O@MBgZHu2ojc*O!*p zZeCtmxqNwf>H2c1vUKCd>gwj|`s!xo`t=*EVRl=7p1uvqW%QjL^u-hOX8+Psac!xz zo-Y@#tgK!s)0_2+m)Dn9*UIJa{r=_UKxBQXv^HAg=vs>O>4GKxmoFCSQv~$)E`NUb zG4B75gV|yXAOHd&00JNY0w4eaAOHd&00JN|nm};>AD{mp%?CLk00JNY0w4eaAOHd& z00JNY0wC~hCve=){|E2?ryu>p0|Y<-1V8`;KmY_l00ck)1V8`;Kwyjs1o!{3{~u!o zV`v}%0w4eaAOHd&00JNY0w4eaAV37L|A!fX00@8p2!H?xfB*=900@8p2!O!&6Ttp| z{B?{Gf&d7B00@8p2!H?xfB*=900@9U@cI9V*?&#t>hy#M2!H?xfB*=900@8p2!H?x zfB*=5+X?hupP0RW|J7`E=1lh5oma9mjg`h~ef7Qa(&}0Qiavk<2!H?xfB*=900@8p2!H?xfPg|E*#95r zO$&Vge;h&l{-44ZCm;X#3c|znss$B&$jB1jSh1?$2MDrS*tP6Gkj;C z*`5*n$f9A@L&SgDHj;aAW^=o;aj%lUwRNNN(|q;lu2%DRxAN8G%BmNt(SWLjCp)tl zeSKZ~e8V?)TP(Ria%wK>)4h%Bw=2A@KWb$OdZ^|v(9o-SpFQ>$<1s|1w%KMPYtQud zSUuF9NMnyXmTUT!V;gnTC(+>9x7uv6T5}n_t{c0DLRh4jzAV==JwLDo>Ox%j!|hvl zHntz+Kd3widskT8yt}n?Z+qj`*1deSvCoaiGdiZ>dq%UGzj=4Na_jxAs6ecU{C4GL zWxKMqS=q_+p?Rbx8f>AkJCo5b>soIrw$h-G7nrEiwSBTnem*;FB7<_FQSel%ykDLg zIi1$8>Z9gOSWk%`W}aR>mCwxrJ37F5k! zp4SZu2bPz=QMtKs|MtE70{wXE!eZ63eb$UD%ViDHM>WsWn46F(3Hyf~G5DjcC}J1- zq69;d*cKF95(Q&(njSQ#4pYsEpPw2wr$IU4W2p3iIo&>)*3avj6OXtv5rby}w1aNM>cHBgV^83!~na!`uYBR?^@Yb7R!ok`5H zng7Uf_YJdNcPX|pbrjyA=d*S&11$=FG)Z=utMnrh=(tY3Tl1Ca#)lZ?w@tg-Fl&C- z9VpN21Iu;nHnYj6h6>>28dj^{6&ZJ(0&mOQ?V}C!Yk0otcY_Ra#U5kI+?54tj%~9V zd39;V%lUlb8$2kClyTXoU5i*5W`~A-V3Lw!?-95+KGh>z^7dp)^5~S=K}40?m)Z;i zQG!S?u46hS*4bm^A(d{CmkXczFAAT6gk{Oqvq6OL+xi;{-L7ZIH$=Ro}u`- z=kQ&8h7Je8WBaZ`_E}*nti8k!ClAsY{o1_NI~)6;c)H7>v!9q5c1rzBDcFV%s-HFB z`kuaXBBRgGYhUWzS@q@muclPD)z5ei-z85JDUOT?U!j@PxRvEhw-c;l>I11RJD^3c zY_??Tw3eck5_ya`uqCa9stIwo>zPevw5+y8ZO5e(O!qm=s88$#Y-xEat-n7%YRMxt zQR=@=?@ng)@WqE(MxUG0enGB| z9Egw*#|hPXM6zl-V*Ps5iZ*I9^2gzsK^}xP9XDPK)50vu>9|(gB)1VYD!D@uT&z0z zcdKFboxjMTqwxD8i9GyZQW$?EO9`R?DZsbvCUwK|xKA}G%8*$Fv1DMG7mIRVe+-BF z>;h6pVnLTRyDig=*7XM#dt?kRE@SCv&q?v4V>U;5{aE@FiuO^7;nmX`vV2yh8OE3PjT7pq-z2^0`5@*s|@HZVMq3zsI+`sUoIxh725 zD2Z2-wA0Q8yX~3m?^1f|$Ej1L>0h1vuPfCO8=zT$?q z)zYA!y&Ky6L0bPwX_RFSvo%V?Y4eRYGx~dHwcg3txadNYAFLbWEqa*n!(k&Dk__WQ zm^@%CPtU!P(a)aMz6^~elnI~bRHKNB=jH1q{muwIMt)eF4k&S-v5vtjk`zrvq*ZT% z^#txjRLEV6YC6L8hijI2J46)=wg4ai}ufsFd0=^vKPu2a6eB@e~&LrpYHdKg``9adV-Mz zl^o@Sp4UP`{gaJ*!9@n6w#RDwLa|3{F(MO9EZ#r$ifkVpX-mE+D@sP1EL?i^S&JG$ z#Rzpq4ZVKqFTzDlWOM#Yo2&+ypt`CDAMVTIxI z%c^3dww^<+^3~Htx+xJ{FpT_zSWX7lDgGZWo`!vY0oVTA_l%P7o&3q}=X{V$>&E&> zXF}u?Zg0^glce;QQYr1dKibWqgoahuFmOdpjvU{u`#GPS%jlObX}!sCNk4emRV;zx z>#!Hp`8AXuVB+hp{^e}#rwJ%DZ&}O{Qmvd<&5nMy)iUs0u@Gju}|KbV;jvxR6 zAOHd&00JNY0w4eaAOHd&Funxv{eR=DTZ{|@KmY_l00ck)1V8`;KmY_l00aa9-2WFO za0CGm009sH0T2KI5C8!X009sHf$=4P{r~vt79#@z5C8!X009sH0T2KI5C8!X00BV& z@BbGha0CGm009sH0T2KI5C8!X009sHf$=4P`~TysTZ{|@KmY_l00ck)1V8`;KmY_l z00aa9?EeJ`96s009sH0T2KI z5C8!X009sHfk6V;{|^$uIS7CN2!H?xfB*=900@8p2!H?xj2{8K|9|{+iV=YT2!H?x VfB*=900@8p2!H?xfWRPu{{>_kT=f6| diff --git a/management/client/client_test.go b/management/client/client_test.go index 313a67617db..100b3fcaa12 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -4,7 +4,6 @@ import ( "context" "net" "os" - "path/filepath" "sync" "testing" "time" @@ -58,7 +57,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := NewSqliteTestStore(t, 
context.Background(), "../server/testdata/store.sqlite") + store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -514,22 +513,3 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match") assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") } - -func NewSqliteTestStore(t *testing.T, ctx context.Context, testFile string) (mgmt.Store, func(), error) { - t.Helper() - dataDir := t.TempDir() - err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) - if err != nil { - t.Fatal(err) - } - - store, err := mgmt.NewSqliteStore(ctx, dataDir, nil) - if err != nil { - return nil, nil, err - } - - return store, func() { - store.Close(ctx) - os.Remove(filepath.Join(dataDir, "store.db")) - }, nil -} diff --git a/management/server/account_test.go b/management/server/account_test.go index c417e4bc89b..4dd58e88e0d 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2423,7 +2423,7 @@ func createManager(t TB) (*DefaultAccountManager, error) { func createStore(t TB) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 23941495e8b..c7f435b688d 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -210,7 +210,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { func createDNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index f8ab46d8176..dc8765e197f 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -88,7 +88,7 @@ func getServerKey(client mgmtProto.ManagementServiceClient) (*wgtypes.Key, error func Test_SyncProtocol(t *testing.T) { dir := t.TempDir() - mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ + mgmtServer, _, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -413,7 +413,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), testFile) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } @@ -471,6 +471,7 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie } func Test_SyncStatusRace(t *testing.T) { + t.Skip() if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { t.Skip("Skipping on CI and Postgres store") } @@ -482,9 +483,10 @@ 
func Test_SyncStatusRace(t *testing.T) { } func testSyncStatusRace(t *testing.T) { t.Helper() + t.Skip() dir := t.TempDir() - mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ + mgmtServer, am, mgmtAddr, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -627,6 +629,7 @@ func testSyncStatusRace(t *testing.T) { } func Test_LoginPerformance(t *testing.T) { + t.Skip() if os.Getenv("CI") == "true" || runtime.GOOS == "windows" { t.Skip("Skipping test on CI or Windows") } @@ -655,7 +658,7 @@ func Test_LoginPerformance(t *testing.T) { t.Helper() dir := t.TempDir() - mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sqlite", &Config{ + mgmtServer, am, _, cleanup, err := startManagementForTest(t, "testdata/store_with_expired_peers.sql", &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", diff --git a/management/server/management_test.go b/management/server/management_test.go index ba27dc5e885..d53c177d6b8 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -58,7 +58,7 @@ var _ = Describe("Management service", func() { Expect(err).NotTo(HaveOccurred()) config.Datadir = dataDir - s, listener = startServer(config, dataDir, "testdata/store.sqlite") + s, listener = startServer(config, dataDir, "testdata/store.sql") addr = listener.Addr().String() client, conn = createRawClient(addr) @@ -532,7 +532,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromSqlite(context.Background(), testFile, dataDir) + store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 7dbd4420c10..8a3fe6eb049 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -773,7 +773,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { func createNSStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 225571f624f..f3bf0ddba78 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1004,7 +1004,7 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1069,7 +1069,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil 
{ t.Fatal(err) } @@ -1135,7 +1135,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1188,6 +1188,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed) + assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC()) assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) } diff --git a/management/server/route_test.go b/management/server/route_test.go index fbe0221020a..09cbe53ff53 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1257,7 +1257,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { func createRouterStore(t *testing.T) (Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSqlite(context.Background(), "", dataDir) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index fe4dcafdb26..615203bee38 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -911,28 +911,6 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data return store, nil } -// NewPostgresqlStoreFromFileStore restores a store from FileStore and stores Postgres DB. -func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewPostgresqlStore(ctx, dsn, metrics) - if err != nil { - return nil, err - } - - err = store.SaveInstallationID(ctx, fileStore.InstallationID) - if err != nil { - return nil, err - } - - for _, account := range fileStore.GetAllAccounts(ctx) { - err := store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } - - return store, nil -} - // NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. 
func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { store, err := NewPostgresqlStore(ctx, dsn, metrics) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 4eed09c69b6..06e118fd22b 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -11,14 +11,13 @@ import ( "testing" "time" - nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/netbirdio/netbird/management/server/testutil" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" + route2 "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/status" @@ -31,7 +30,10 @@ func TestSqlite_NewStore(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") @@ -39,15 +41,23 @@ func TestSqlite_NewStore(t *testing.T) { } func TestSqlite_SaveAccount_Large(t *testing.T) { - if runtime.GOOS != "linux" && os.Getenv("CI") == "true" || runtime.GOOS == "windows" { - t.Skip("skip large test on non-linux OS due to environment restrictions") + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } + t.Run("SQLite", func(t *testing.T) { - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) runLargeTest(t, store) }) + // create store outside to have a better time counter for the test - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) t.Run("PostgreSQL", func(t *testing.T) { runLargeTest(t, store) }) @@ -199,7 +209,10 @@ func TestSqlite_SaveAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() @@ -213,7 +226,7 @@ func TestSqlite_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") @@ -271,7 +284,10 @@ func TestSqlite_DeleteAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store := newSqliteStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) 
+ store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) testUserID := "testuser" user := NewAdminUser(testUserID) @@ -293,7 +309,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 1 { @@ -324,7 +340,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { for _, policy := range account.Policies { var rules []*PolicyRule - err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") @@ -332,7 +348,7 @@ func TestSqlite_DeleteAccount(t *testing.T) { for _, accountUser := range account.Users { var pats []*PersonalAccessToken - err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -345,11 +361,10 @@ func TestSqlite_GetAccount(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - if err != nil { - t.Fatal(err) - } - defer cleanup() + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -369,11 +384,10 @@ func TestSqlite_SavePeer(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - if err != nil { - t.Fatal(err) - } - defer cleanup() + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -421,11 +435,10 @@ func TestSqlite_SavePeerStatus(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -478,11 +491,11 @@ func TestSqlite_SavePeerLocation(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := 
NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -532,11 +545,11 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + existingDomain := "test.com" account, err := store.GetAccountByPrivateDomain(context.Background(), existingDomain) @@ -555,11 +568,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -579,11 +592,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/store.sqlite") - defer cleanup() - if err != nil { - t.Fatal(err) - } + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + id := "9dj38s35-63fb-11ec-90d6-0242ac120003" user, err := store.GetUserByTokenID(context.Background(), id) @@ -598,13 +611,18 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { } func TestMigrate(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newSqliteStore(t) + // TODO: figure out why this fails on postgres + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - err := migrate(context.Background(), store.db) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) + + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on empty db") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") @@ -640,7 +658,7 @@ func TestMigrate(t *testing.T) { }, } - err = store.db.Save(act).Error + err = store.(*SqlStore).db.Save(act).Error require.NoError(t, err, "Failed to insert Gob data") type route struct { @@ -656,16 +674,16 @@ func TestMigrate(t *testing.T) { Route: route2.Route{ID: "route1"}, } - err = store.db.Save(rt).Error + err = store.(*SqlStore).db.Save(rt).Error require.NoError(t, 
err, "Failed to insert Gob data") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on gob populated db") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") - err = store.db.Delete(rt).Where("id = ?", "route1").Error + err = store.(*SqlStore).db.Delete(rt).Where("id = ?", "route1").Error require.NoError(t, err, "Failed to delete Gob data") prefix = netip.MustParsePrefix("12.0.0.0/24") @@ -675,13 +693,13 @@ func TestMigrate(t *testing.T) { Peer: "peer-id", } - err = store.db.Save(nRT).Error + err = store.(*SqlStore).db.Save(nRT).Error require.NoError(t, err, "Failed to insert json nil slice data") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on json nil slice populated db") - err = migrate(context.Background(), store.db) + err = migrate(context.Background(), store.(*SqlStore).db) require.NoError(t, err, "Migration should not fail on migrated db") } @@ -716,63 +734,15 @@ func newAccount(store Store, id int) error { return store.SaveAccount(context.Background(), account) } -func newPostgresqlStore(t *testing.T) *SqlStore { - t.Helper() - - cleanUp, err := testutil.CreatePGDB() - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUp) - - postgresDsn, ok := os.LookupEnv(postgresDsnEnv) - if !ok { - t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) - } - - store, err := NewPostgresqlStore(context.Background(), postgresDsn, nil) - if err != nil { - t.Fatalf("could not initialize postgresql store: %s", err) - } - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - -func newPostgresqlStoreFromSqlite(t *testing.T, filename string) *SqlStore { - t.Helper() - - store, cleanUpQ, err := NewSqliteTestStore(context.Background(), t.TempDir(), filename) - t.Cleanup(cleanUpQ) - if err != nil { - return nil - } - - cleanUpP, err := testutil.CreatePGDB() - if err != nil { - t.Fatal(err) - } - t.Cleanup(cleanUpP) - - postgresDsn, ok := os.LookupEnv(postgresDsnEnv) - if !ok { - t.Fatalf("could not initialize postgresql store: %s is not set", postgresDsnEnv) - } - - pstore, err := NewPostgresqlStoreFromSqlStore(context.Background(), store, postgresDsn, nil) - require.NoError(t, err) - require.NotNil(t, store) - - return pstore -} - func TestPostgresql_NewStore(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expected to create a new empty Accounts map when creating a new FileStore") @@ -780,11 +750,14 @@ func TestPostgresql_NewStore(t *testing.T) { } func TestPostgresql_SaveAccount(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + 
t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey := GenerateDefaultSetupKey() @@ -798,7 +771,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") @@ -852,11 +825,14 @@ func TestPostgresql_SaveAccount(t *testing.T) { } func TestPostgresql_DeleteAccount(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStore(t) + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) testUserID := "testuser" user := NewAdminUser(testUserID) @@ -878,7 +854,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } account.Users[testUserID] = user - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) require.NoError(t, err) if len(store.GetAllAccounts(context.Background())) != 1 { @@ -909,7 +885,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { for _, policy := range account.Policies { var rules []*PolicyRule - err = store.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") @@ -917,7 +893,7 @@ func TestPostgresql_DeleteAccount(t *testing.T) { for _, accountUser := range account.Users { var pats []*PersonalAccessToken - err = store.db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -926,11 +902,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } func TestPostgresql_SavePeerStatus(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) account, err := store.GetAccount(context.Background(), 
"bf1c8084-ba50-4ce7-9439-34653001fc3b") require.NoError(t, err) @@ -965,11 +944,14 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { } func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) existingDomain := "test.com" @@ -982,11 +964,14 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { } func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -997,11 +982,14 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { } func TestPostgresql_GetUserByTokenID(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skipf("The PostgreSQL store is not properly supported by %s yet", runtime.GOOS) + if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { + t.Skip("skip CI tests on darwin and windows") } - store := newPostgresqlStoreFromSqlite(t, "testdata/store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanUp) + assert.NoError(t, err) id := "9dj38s35-63fb-11ec-90d6-0242ac120003" @@ -1011,11 +999,8 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { } func TestSqlite_GetTakenIPs(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) defer cleanup() if err != nil { t.Fatal(err) @@ -1059,11 +1044,8 @@ func TestSqlite_GetTakenIPs(t *testing.T) { } func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { return } @@ -1104,11 +1086,8 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { } func TestSqlite_GetAccountNetwork(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not 
properly supported by Windows yet") - } - - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1130,10 +1109,8 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { } func TestSqlite_GetSetupKeyBySecret(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1152,11 +1129,8 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { } func TestSqlite_incrementSetupKeyUsage(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("The SQLite store is not properly supported by Windows yet") - } - - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1187,11 +1161,13 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { } func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { - store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite") + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) } + group := &nbgroup.Group{ ID: "group-id", AccountID: "account-id", diff --git a/management/server/store.go b/management/server/store.go index 50bc6afdfd2..d914bb8f7d5 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -9,10 +9,12 @@ import ( "os" "path" "path/filepath" + "runtime" "strings" "time" log "github.com/sirupsen/logrus" + "gorm.io/driver/sqlite" "gorm.io/gorm" "github.com/netbirdio/netbird/dns" @@ -240,28 +242,39 @@ func getMigrations(ctx context.Context) []migrationFunc { } } -// NewTestStoreFromSqlite is only used in tests -func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string) (Store, func(), error) { - // if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE +// NewTestStoreFromSQL is only used in tests. It will create a test database based on the store engine set in env. +// Optionally it can load a SQL file into the database.
If the filename is empty it will return an empty database. +func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (Store, func(), error) { kind := getStoreEngineFromEnv() if kind == "" { kind = SqliteStoreEngine } - var store *SqlStore - var err error - var cleanUp func() + storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName) + if runtime.GOOS == "windows" { + // To avoid `The process cannot access the file because it is being used by another process` on Windows + storeStr = storeSqliteFileName + } + + file := filepath.Join(dataDir, storeStr) + db, err := gorm.Open(sqlite.Open(file), getGormConfig()) + if err != nil { + return nil, nil, err + } - if filename == "" { - store, err = NewSqliteStore(ctx, dataDir, nil) - cleanUp = func() { - store.Close(ctx) + if filename != "" { + err = loadSQL(db, filename) + if err != nil { + return nil, nil, fmt.Errorf("failed to load SQL file: %v", err) } - } else { - store, cleanUp, err = NewSqliteTestStore(ctx, dataDir, filename) } + + store, err := NewSqlStore(ctx, db, SqliteStoreEngine, nil) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("failed to create test store: %v", err) + } + cleanUp := func() { + store.Close(ctx) } if kind == PostgresStoreEngine { @@ -284,21 +297,25 @@ func NewTestStoreFromSqlite(ctx context.Context, filename string, dataDir string return store, cleanUp, nil } -func NewSqliteTestStore(ctx context.Context, dataDir string, testFile string) (*SqlStore, func(), error) { - err := util.CopyFileContents(testFile, filepath.Join(dataDir, "store.db")) +func loadSQL(db *gorm.DB, filepath string) error { + sqlContent, err := os.ReadFile(filepath) if err != nil { - return nil, nil, err + return err } - store, err := NewSqliteStore(ctx, dataDir, nil) - if err != nil { - return nil, nil, err + queries := strings.Split(string(sqlContent), ";") + + for _, query := range queries { + query = strings.TrimSpace(query) + if query != "" { + err := db.Exec(query).Error + if err != nil { + return err + } + } } - return store, func() { - store.Close(ctx) - os.Remove(filepath.Join(dataDir, "store.db")) - }, nil + return nil } // MigrateFileStoreToSqlite migrates the file store to the SQLite store.
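A minimal sketch of how a management/server test is expected to use the new helper, pieced together from the hunks above. The test name is illustrative only; the fixture path and account ID come from testdata/extended-store.sql, and testify's require is assumed, as in the existing tests:

package server

import (
	"context"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestSketch_LoadExtendedStoreFixture(t *testing.T) {
	// Pin the engine for this test; the Postgres tests above set
	// NETBIRD_STORE_ENGINE to PostgresStoreEngine instead.
	t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine))

	// Load the plain-SQL fixture into a fresh database created under t.TempDir().
	store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
	t.Cleanup(cleanUp)
	require.NoError(t, err)

	// The returned Store is usable like any other store implementation.
	account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
	require.NoError(t, err)
	require.NotNil(t, account)
}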
diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql new file mode 100644 index 00000000000..b522741e7e0 --- /dev/null +++ b/management/server/testdata/extended-store.sql @@ -0,0 +1,37 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES 
`accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:01:38.210014+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cfefqs706sqkneg59g2g"]',0,0); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["abcd"]',0,0); +INSERT INTO users 
VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/extended-store.sqlite b/management/server/testdata/extended-store.sqlite deleted file mode 100644 index 81aea8118ccf7d3af562ddece1f3007f1e8fc942..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163840 zcmeI5Piz}ke#beYMN+aQ#_=W-C(-UmiDMI%Y*LhD$qU1kW!f=p%ZX%R*D!+NhyTukw&U@-<(bEEZE!smbz2#6Km;T zzyElfJ{<86=v@kZ9`y0u#5wD4N4aqJTQ2iP_S^B_9lJk%F*Q1#9sPQAGr2Q*CV52C zM!rt`KJg{}qsQ6TQ}X4LQbISX?2-MTK@C*3wquH8wQia1maP`t_Y<9|*V6LGD^ll@ z$kwZm)E2XB)6g2KR;e)CRvmMf8MfNCnWgI0_}t@b8>Q8|CFREY)zVLu^1;=Vl{@Q7 zIli*;nX-4Xa_;fwWLjQXkv?B_wCx6qFA$%a_1biI^~$Z%z;h{QsI_IqVUL{I;I!VM zp*5MGwWHZPtm?w>)7YbyZfTBg8fsN@NR+=idXvqTD;A^ARdsvc7xq&$CoI>{Y=@(z zCIp2)+PHChb>p6LvvkjGU2b;m&idxvjny0Lca?H&SLni4TbkQj5wa`H<4aB*C<3PwX|EloO^QVXj+zK>FGUT20}!<9E&R1EgR8ks2WeLme~l@ zs@p0xc2BP`0ky0c?IyGI%52%-T4fC;P8doLi>qM^;bY7jUcSz^8)jvf>o%xW=(}Oo zbt9zkz*?fVq1$$w7v{F9TrFK&-MV#GnbBJM%xqaV99H*i&tf%#TvkLHvtlws!p?rn z@BG1b=CN}Je~hVx*ybYE<`?u0E_W;?Uy=tIoY0zD5N{LP9jD;Up74-MOBD;GEw6cY2^cq zQ=G!}MmBTWEj5&&w|v8Dl3A)drtO3cP$bq()wCmGyUa3JBa&1xE#C;rqExG)IXs9( zqUDGN~+v+LLj6K~ljV3e5LAweNYOnu1c6QA= zluMDr+E*)PH89j}rta)V^c&>VnJF?N?(jL>oEV|~5p{SsmwA?(b8DwCN~36OFO`-r zPfMMXfe#9XdpLA<6Vv@pshb%JwyuNfX7#wfCm$S1%hS`+Q&~8xu3Y!iXw+?WGoHhD zg$GI~*6K`c z=uMr@9h8nS-RCf)F0mJ|rQ&ExzCJxDWQI;r=zdOZC)09qO6o{~Rr!%vt(r|uH~hmR z{nq4Xd6KRj`H5sjZUyDs7avJ!d1^}f1-UwMAijkEJ`%MaKRIeU{$%)|>9KcC&mX%- zDtQoAH?3f1O>;Oer)B9)jogNJ(vUm!gNt4z|E|~cuJiYE=-v1GeiC{3-Xu5vge-+e z0aAdY8+B@iZVR8PQj{UH;<2P>Mi__+Uq6JyeRhFRLjpmI)!Plt@)ikuI(w-0&qD)g zZ#_uyqovgcdHq28Q;POpiT>CSCsk=^dNbmzBT`zUVLlv$h^D=kM$1tOWPX5)30gsY zuDCxjCYL|~Q6~>#TG12q59HkWk(9hPHE68)Caf0xPf5PcP9%OXIYF;d_G*Hjcz^&1 zfB*=900@8p2!H?xfB*=900=yTKxajIV{78})H`EaKbgoJ&x~d=8NROGt`#cF`Q?SV zZEZ0>w@_gh=2jMpD|5w#rNv@CU#L}z+w=MPg}Fk0EtFu*P9iZ*?Ub^AO3)Jz5C8!X009sH0T2KI5C8!X z009sHfme*c>EzbfF#8Yg`v03m_M2C%gJ>=YfB*=900@8p2!H?xfB*=900@ADm`X{~f+ZJ3L{!_;$4Yd4&I zg}%f`d8j)(3Vo}T!XD{1|Gof4{3X4`=qo7bdlcwa_QG6#d9JXc%;zsIEM1(Zd(ZQW zE5)Vd`FZzl_W2?gS(z&=4;DGNmOR}aKPR3_K2P_^)AOu&y+5OESE@6gJU%0FV()px zE)7NO%6CNU%2ih5%s8@_2~DNOY7guS`6c_ouEFYyEA{#M%qRT%|48;IQQt!M)XfB*=900@8p z2!H?xfB*=900=|~@b&*N_bG_=|1g65{(t(>KRiGH1V8`;KmY_l00ck)1V8`;KmY^| zF#&%4AM5`^tYCBv1V8`;KmY_l00ck)1V8`;KmY`Y0M`F70}ucK5C8!X009sH0T2KI z5C8!XIQ#^#{y+RWMh`&%1V8`;KmY_l00ck)1V8`;KtTNd-{e0fvg`DM2MB-w2!H?x zfB*=900@8p2!H?xygURtZ;ecDZM~7nOdQW#zI`GyQCp}jRhKTz&n*>e+jE5itIVzB zt4nkF`Gt9{QkbXT3oH)vU3hJqS@%B~`0Il7`-Tg1g*;WO6qYXL3m1#a=jIFfr3=gM zj~i+&y>niorpeT_OK#IxRjV>rc;rAjf#`pifJcf;yfB*=9 
zEeXmMz7@AFIVqvVUI2t|G#HY4Za+D(V_ENa~~%d|)1 zrD8iFinyL#vZIZ8!*;x=kbj~(hETwZ3+{kXd>&souZX&jp3`2F8$+&)&A)NU1-i#pwCC<@jA}&vnuMg+3+| zmzL~>GCJJPvvWV<3)AQOZDS!R(Sq8;Nc>8Ub3&)JkU;-r^K`-sG(!Nma@mdVY zgcFPRFMdU~jgGV>Uz8OkLrrF`y?)f9Mo`fMol!$?p8d;UQB!%XGhB47KZ1lex^>ni z8E7o@gYk>Hr@B$`?7j0mwtVp}S8h;Y@cwFCv0+=MP^*0Pbe3*P_!kUA{~(r={&kA~ zhl{5{-(SGBpZcCr?|UbIw);8v(@9ld>FZ30e2?2(bjhSwI(_d}S^AUR9O}_<+%{{;ygK>!3m z00ck)1V8`;KmY_l00cl_a0&R&|10V56ZC@@2!H?xfB*=900@8p2!H?xfB*=9z~B)W zR)!N3`2PRkDHbCF0T2KI5C8!X009sH0T2KI5CDNr0@(j|62UnLfB*=900@8p2!H?x zfB*=900;~o0et^|@N|k1fdB}A00@8p2!H?xfB*=900@9UCjspLJBi>N1V8`;KmY_l w00ck)1V8`;KmY^=j{u(kA3U96L?8eHAOHd&00JNY0w4eaAOHd&&`IEb0dP~Nf&c&j diff --git a/management/server/testdata/store_policy_migrate.sql b/management/server/testdata/store_policy_migrate.sql new file mode 100644 index 00000000000..a9360e9d65c --- /dev/null +++ b/management/server/testdata/store_policy_migrate.sql @@ -0,0 +1,35 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` 
text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:04:23.538411+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 
20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO peers VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=','','"100.103.179.238"','Ubuntu-2204-jammy-amd64-base','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'crocodile','crocodile','2023-02-13 12:37:12.635454796+00:00',1,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K',0,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-34653001fc3b','zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=','','"100.103.26.180"','borg','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'dingo','dingo','2023-02-21 09:37:42.565899199+00:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAILHW',1,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,''); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store_policy_migrate.sqlite b/management/server/testdata/store_policy_migrate.sqlite deleted file mode 100644 index 0c1a491a68d58e019b5256b60338ab8b5f5e40d4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163840 zcmeI5O>7%SmdDwmB~r2}#_=Q_CDDXd;@CtanawYY9vEI)q7#j6*^w+b9wTTr$s*Mj z*-dvhWov>gKsyO$2UyGwHpc~KPji^V%wmAqoaeOI!(jKaK(343!(8^3!vdMZzW!u4 zKdcW=7MABv%Otz%Rdvzx=C()cwIJFqONK!$w*{EFNHWR^jC>QC=_~={-2}& z?&EcObId)USI+l-)Y~^hi`qYqo5JyLO_^8X-%kDZ#J#DD6XR3y@vp`=!dv6#!pAsi z?5oi4LtoNAdc61Y1b=yk3&~22J<=aEsez)@F;weVY-nnyt&4*Fda677QiT7oz;!QK z*>dfX*k+opDpFIFs#T`zqM>dxMHf3d(?q#8Jon`CdZ~E3lvrI`DSebE?_EtfadRzE z9$s1bT-iBUIr(^FI>HwU+-F5Ysx;Z~0>e||PMdBQuUs#UJXhizwYHov*dt@!JFRo5 zNG;}OZAtnTtJyHzH1?=1Ymym&7*Y75-mG0WDOU^IfT-&(4UR+(fohaA0EnVnhTM`XjY?KpMZ?2bC-(PbI zcY@@JJsinR0<>Zq~$0Iz?b6?!G%)k<{o=yanY?lq_)K`t8R!wbs zYL#`78oMJ`nFY0+P&zH9$<_I?VrrE&nRP;6x?f!MTZlQvoZ;o_ywX&w+oo>4S}lDy z)rPG26z*9|&^BaU@0f*6+f1yKt`;}1-%iX)ZFz3KEGq_UIJT#;IzcWctTd(#$@B>u zyKT4gd)t}A&K~?jOwGqO8L&3Dplfi+6BGO;ew4vkS`!89C03d0H>WlWnv->We9)Zw z_rTfh3`s4(Emgi*8V(p^ajNp*;)Tp{iAXurHko-v`Vj1VpT8@vRDy({CEV~MPGv4b`7gVW+`r|y5ToKE3u)9svZ#AW}3p9fuyRcxkga7O0}Dk zVFs~)R9`)+F78mz$lC>^9vVZFJMz$EyD>wv6lOeBwQW(V)ijDbEZ7TuS2tM89IP^h zCK^u_rUlyJ2(&e|)~OnSk!f|&$!|$Yr!G~Ej@DP6DLb;JDlMjvgSHi5m8;9mUQ_(I zYZMflQl*F1)UT);lF>0U$gk^+1%|m_psFegt5S&YIR}5fOa5loOV-5JjZdBfcwI(vzcL8sKq^aWenLG`lwT;G#-k45;?r@1e9%USj0dhf=A zZmXAZ0N>>wD87(0BGTDrPW=wdjXG^}N>uAh)z}VANd0C@rcM(xns$-L@M7SwiB)+> zi0;no#lQ?u8tgtE8)J51+B+T4%&`88QyDp z?3~l_$M%s(9)vYi&6`=%9L~vUYjR5>x8a=B=MLTABG<^j%XPWu{M{USb^X4ZL>|6B z$&NohOEIGWDZr4G1~o(0EuShJWvq;#H z*+X$~9_mRu>p_YiZK*NJ>wD7oDcUg6 za|7IvpykzPi@OtJatRa=W%3}Z7Cgb=Ku(??o8XscM~yYtghj!97dF?~q0slHWAx0$ 
zS3>lM2MB-w2!H?xfB*=900@8p2!H?xfWT7-bPL=oo3R_SZ%l69iA7IF$D`4xxvs9% zh3ZmjDZ5aSa;b%Em0egUWHW_@Og5j(q*6k?nyI8y>Fk1#T1cf6LjGbZdoi6^%w?9c zg797{eKD0Xm~Jdq)fR0G?a;1~WK~H;wxq08WtyrKmh378tIFypbE~)JF3t(5)M7rn zm|9HD&ClJC^leHze}1W0EMB@a_vyWl*S`Gu%}{8H+R4Sg4$&VTAOHd&00JNY0w4ea zAOHd&00JNY0!NI%+3@D%e)b>i_5U}a_%}zagJ>=YfB*=900@8p2!H?xfB*=900@A< zQwbao&rZ7c71;Ox$3yY>Q;Q%I1V8`;KmY_l00ck)1V8`;KmY_@Bm!N9i^n!M-=5vr z*oa0ipNvNLa}m9mUS2BY(^nQUOILFX*<3EOa3v#@7FJgB`Q>~ezg)U9AFQ(~c{Zy`y$uFhT_U-@aj44uB5SB)Z>|INW zK5VdH{UuT<`cMJ=y=Oh&``Etz|EEy=PcM=h(H9T^0T2KI5C8!X009sH0T2KI5C8#} zKsX$m#OMEA8h8Z(5C8!X009sH0T2KI5C8!X0D+@NV85ULH}C%+JMp{F^i_Jo0|Y<- z1V8`;KmY_l00ck)1fDknPo~Gt<~CVu;T8Hx<(Zk-+p$k88(Hl3S zCu6Z_wayxOy`IYJ54IK7$Q2s7)PC;h{_&0C&1;)%_3mn?qhyuTqg(4jC9_()`=eVe zefR#|)|K@AJIkG>x}3daebc&-$}Fbyi^5WBuA*vBUs`(NbQ_XE!VK zo1mQ<+hI+$ZT__0Cv(e9wNs-nW~Z-czvcXbt$FlmjedT`2XE^wAY z&X@WW^441$lcbQ%s81pF2l@q{Lb8qhDCEX!uJu9vhg+3K<5qnobM--KRsX29^ii&| zaffZLrmt_^EZk_9O2(aHPa(pE!eTnJG-rcfNT*ZTh5J&gwYwm-YI*Z}`E~ZJnyP9l z{l+W(#IE;Gf%e+$L`NKbfx00@8p2!H?xfB*=900@8p2!H?x90dZeg>&4nJ+Co7oSUSK zfxQ=d{r`(l{NLlhI0_v^8$kdBKmY_l00ck)1V8`;KmY_l-~}X*3N!8?_o9rQ53|XG z+$&+P|9=yTe;xnL3&;;$0s#;J0T2KI5C8!X009sH0T2Lzqe$Rfc$3@D$0f!N^!fkk ze+|WJ^n?cpfB*=900@8p2!H?xfB*=900=zq1iG({O>b_#5{<@AMlatu6^-rZTM_oN zl2s(_3CO^E*R009sH0T2KI5C8!X009sH0T2KI5csYMbhv14^VF&BXslXi^#}Te zR9=6ut*}O}(8x6Q^J@UbW|Mww;NIL|`D|lu-b`-f^}4UxTx#xP^ZGv*|6_>$@Bjf2 z009sH0T2KI5C8!X009sH0T4Li1o&`na&TiH(}>0YH#E)tI&|Wj_`jU^yQ$xvxHoli zVtgt-{?+(K`0vA8;bYvd$G!^vKJ+F1qsMzMPwMwB(|BRtBTYVrD~Pwx@f4|OwmR2M;1l7W<&mN&oR6O zCzsbt#oMLC>e@=_qeOY{rj`>o*AnI7m6gwxoeq_gk2j_xe4)U7Ry3qalMOF0JT>mL z>2~qT^^$e2KWk+}^iWQmqpp_|276@8dwp;Y6{*GCtSw33Vl`WPZW?>kmNm(cRYj~x z28o)_hTLNF<*LT$eNC+Fy25UXWccNpl5Uu`Kuz!p-&$Y2QCz>9xK_Grw=Oxqd~=v|d_UE^Q>NuIZ#E>TEJu ziADGv&vnN=D>V!0rirQ@#UQ)1j;9Arq+iZ93M-W=@0F*MoS5J*@uS9#U5{YB#GYI_ z9^rYO`{J%;F;=4WbRuXER$5SFzB4*X)6}M?G+7s^UpsP@S;)!>rPE@XT%9i~rp{QC zSts>rpjCNO}m8M$VHZ|wfYH72nHe|)8PR|yCnwNFGV-_|oFR@a(THL&T zJ26N5nRD}HSut3{u`G?%Ngw5emBzFonLc4-x9xU*Z!2=x*}iCqA^F%Q1GeNA^vr3h z-<;YkXinDg@j-Lymvem#mF_dA>ys1wS)P+Ui?xexGwLx(PmQXZ1cH^C3CW*CB9_Jd ztmXJ9s=}%#8nieoZ7++nPGLI`of)0B=u42>uC=tt*2FDUH~dy*B{o!1)dOPNOjB4h zkW^JQSF>fSRJ$n|byaHxr26Vnb#aG!Mh+?<_0Sla+>wVS+l?8Tr7+{6s%?u>t)@|I zW5HhNySl+z<_MH2{Lvt(FfGsyN1&~#wNBLt3^%KbPJT;LI(4aPbhN(mOxck&RcSGW ze5$Pgt6W`f_L}0yU8BI;lqx;6rhY};kc^I*L9SS5EHHNc0##K}Se3lGZ^Zlat${CE zL1Cnf#vXKJVkJs#>h_LAO7^^m>E5iN?%0yP<+miA4w+&`RR8&X=d=US5D~Lq`*cdI zy~W5w2AV})&aT?&zZvd&X6?(R$f50uRkh|BYA;hZb_4qLavDsvG6FvC0KUr_p@Tti z-@YpueOj2ZYZt7S$(@M^fB7`mo$-8-H{AW9vzHhdbV|KUU$C_uR4=Q~^*wp_ScE@) zn)`ycoK;V*_ijAswt5){@Lm3a;)^3A!k2C4)bGG_rqecOF}1!_jqT8+*Kf9D>NJ<4 znG$&nFR%@p3zdh&-Ht9bnAntCGM(EiJ;ZbmU`9P+&tOZL@d^I@)1#(5zEcRj_Y;+H zgwM=!UCy&AH@u29wI#`ldw6Wnn%pc$(zYWvk*vtBpq%{t78l`XXStt}t0M>EO1Q6M zLF;jogSO+&U-z2PI_Gr!v3(?x2Vo6W^CrVIF>`X-n%t7eZ8#_OxkER&$Tjlsa$W8@ ze>aC-UBB-pk%#Y3vg41>Qp_ko3NU1)LCuhL%cqJIWyq||SkgDm^F%FQzYmA|^a8$y zc!C;hbefXp%#3TWXB*`kwTCiuO*4!PqcNs@jy~R=`;Yq@-q3 zeK-mc&1E%-mb?_m+yFNuXnFP7;_hshTml6|nLLQ91y3+Ikdx=fCivyqQDeIVOcZ zh<`HuKe1m;|5NPriQiG-KiF$^_qqG(TM_=Iz;z4Wz%@Onwcw!cJII~WQ*;~5oEsd_ zgF^1Wu;Qb+03Q_T3m1<+osIB{H@4V{Or@$fsy3;+ypju^!W% zxD{HiC1^UX>$m4D-h7}X=?^+Il{IIh-V~db2ekb6x!^&w3RX@%G&k*N>TA%(!d_l! 
zq=DJ+o?&Y7-WyNbg>u68uivs-ILuF)FJDiv4>QN%UWs_HGw+p1=Gh0Fs{laz*6&)yEm z>YGd|u@uw4k=N^a(BZhHh`YI)XKu@`l&+U(FL$}Pv0Pj!B`>@-!CyNw$^|*~+QuOG zkKRyTjqpF9$>_?imY7xo6;q7>QZQ zUQXzs7Gmq4Y+N@t8N})qt8TlB-CK(e8E0Tw?F;V6_Rx{G{F}0FNk@~(^RGN@QsYt) zZJp7HUOWER_N1nMsy|!|%s=deR|j=AAZcsN@q@ul-2+{yc>K;uE4F<0wxwLVLi_pU zpkkeK9YC%A-P3uxC}D0GI{v|(PMZ4^{|}p|cH5u9wIBGJ(eP_0-`Vw?Yw-zQER1v} zM1IKSE!r{}mQKIB#a;NbT^t&6;-G8j+fnmJj-z!w=l+=pfBrnz4ck-t{%u!x3gqp> zo(qnzfqWB_HSg+ePfIi{GB+NR-HCT1{QF6+oAd<%zhex90)NKYLFW>5slJkKSn!u8 zH|TejkHd2jKAGg6I2RHF-Z@zDnP7i|nL(TI8^Mt2gRcROKRJ9f2=uM*DxL2&o(JlPu>`V_Fyd7*yAkWpNH_`X!(n7-AL+H2@#KGX5wD14_Iu!rf zUCHAW1V8`;KmY_l00ck)1V8`;KmY_l;HVHd8{VAU&zA<6-~Wf-|2ryLMLR(N1V8`; zKmY_l00ck)1V8`;K;Uo$@cVy<qOK>!3m00ck)1V8`;KmY_l00cnba0IaaKO8rd z0s#;J0T2KI5C8!X009sH0T2LzqeQ^G|DTJ08=^lvKmY_l00ck)1V8`;KmY_l00ck) z1YSG>VJ;ks;q(75o?_7>5C8!X009sH0T2KI5C8!X009u_CxG>TKM@>*00@8p2!H?x zfB*=900@8p2!OzgM*yGyfAMsR9)SP|fB*=900@8p2!H?xfB*=9KtBPj|NDvH7z987 z1V8`;KmY_l00ck)1V8`;UOWQ0|Nq6)DS89~AOHd&00JNY0w4eaAOHd&00R94{ttLC B#c2Ql diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql new file mode 100644 index 00000000000..100a6470f43 --- /dev/null +++ b/management/server/testdata/store_with_expired_peers.sql @@ -0,0 +1,35 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` 
text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts 
VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store_with_expired_peers.sqlite b/management/server/testdata/store_with_expired_peers.sqlite deleted file mode 100644 index ed1133211d28b5b2faf4963e833b2ac521ff7fe0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163840 zcmeI5Pi!06eaAVXB}%esuCv+9D!WQ(InJ7}C2{yi6df2_N}{xi^~b9}>$Ml@a7aE; zW05oT%utrRNCB?1Mi8J#i{@CMIrY#(QUqwW*PMzTg7#7%*P=c2(%T*iY!3a+U(WE4 zCDoh4t9)sh!+G!b-n{qu{oe2Q-pp{g{r;M4u=t*;H6$aR4!suQxX_=(nl&vQR27*e^;`WNV*8gtsTQ&?Cl4m?);2DP>nH`tMJ%R8*I zt4IyzX6;M*KC9X=+%$I7lr_naRYj~y28o*QhTLGcN)?UK=c-sfc7@#($?(h7CEYO5 zQWLzw_qJBo3tJE3cZ(0~)+KH&ZfX^5#}?b!EdT z;IWHu6_<-!#f`<{cHHWkPGqUGiNxL52ruRbnMi+}qR>D6wQ@AVXL+tY>{*c60^Kwp zwWSzjGuHm)UK8n-vyH+^rIvQemlDtJT#WEM&pm%&nSmu@y3(t1Zz1LwbB33z^KxCS9GJTGYPIxT zS8KB3Q@Cd>LEDgZy=4|QZ8N@9TrTXc?Zl^~raX14Br681Iku;Fnv=DEvDcis 
z_rMt{&ZES>phv#I^V(p>_jNp*;_@KH;AXurHkoa*VVj1U88;*~nDy(`&gT`s4 zl}<2D>kzg#(wWg}ioOK7=^9pp%u?J}b;ECfR$@&RRXrefz%+%`14$KCbB&;6m1@={ z!wg~psjhldT|A_ok+%y-Ju&(wx8%Obc4PWxDa?4HY6qfJt!fl^Sg;rRv2L)2Iap;1 zO*Ed$ObfKb5ol^^wN)_!Bh%`lli!e()}B-`T3T0mrX0$esx+8F4%$|LRc=qNcbekI zU8SH{m&zTqrhY};kc^g@L4I9lEHKRd0u@zJScO7_&pG(>EgvCT5oVx_#vZq1VkJsV z>h_^TO7-3t-^>Q(MJZ#k=uT<6np z&~0@x&fvTJ1H~6|21GjB%&FgjxlyZWPKl~rsTw<^38~*~$<%3LM$<0x7+wtQH?b=9 z3DGTGsxh%HH)J}tSGtesp23Vd#9qRd(!(SC%GE(3({~7=^LeBkj_~P8uFZK?dpF4Dei(DoDF7L@5=kMmw*7f^t5_$OU zBs>22EX9lhqyR%!YSau_w|uHdQHIRQj3r$&LQmB4^`~&SFD~F~h$pDATB|N;&LZJZ zW>3W4d8jAttOqH6G^N@gukT4eqG<1w=#35iq$+huZUmflKuW6D)hB}x(X>~SXgNxO z%nfjTf|ggGE$&W?$t6%gl*xmrTJQk91357}G{P@V4jOB&35$aJDQvE@L!qx_*oo-x}M!A053iHat2yYObrx zdqQPCIiHy>OWEXfrowX5`Aj-LozBc<)5)Z;S4o#s$y8=qNKPkHaUpj*nY^7&&16%# zY-;{yGDZ6a(~X&m+Mp9dhji9RvZ|ya`%=cLGDTGi^LCYjRb}PlsnvT^x2J?;a%L_w zlblIT-I`jL^aDzpot-Zf3U}^I{qUn7gqM^5`18;>wUdkeE<}HLfdB}A00@8p2!H?x zfB*=900@8p2z&^z=enC{8ae&CM;&<>wZQ3kyqZPiobT zI6VzXWAw}pdhi6@`!7r<=ck2yJe9ninY*2$oAr~~e0pv^m9lUDPo+(f{IoDXSmfkd zlJsSRY3nae-!QQLZd&gj{lLEd|NBtv_vcBC=nDvd00@8p2!H?xfB*=900@8p2!Mb~ zARLa4;rstC4QxRG1V8`;KmY_l00ck)1V8`;K;WxK;I!ZWH}C%skL#h!-@a6&4PGDs z0w4eaAOHd&00JQJWfOR|8NQL-Wzn0H^n;m-a@Lsfw;qJlmaz%TTnP0rWbFi^no~^30N;V}6 zYuW01nWg=e*3Hbr*1FfMTs)b-O^+PU`ofXfj(@=;z#o{exQSQC7$_lc&+>R(<<8b@)_WNxgTxzOb{q z+W2r&{$PJ`L6#S!?Bk{E(bM^Gre#X;iIzN-mOgX|g$=hdSJJL(6IvO7^gN?JL!y|QuBkju089^PEfRn{{1jh&~v+D`MpJERaNqsz>b>CMaw zxm-R!m#1G3Fg*&1^!s%I-_SL$|A%A04#oZ}_Aj)-3j{y_1V8`;KmY_l00ck)1V8`; zK;V@j@Mbv6^*c8jwljK9u?_Lz>{zAG85(>2uY_X%8vC1=^2&4&4FLfV009sH0T2KI z5C8!X009sHfv+%uWSDVhxi=|n7C6hTFGI6oHg=XfD(v;E15 zc_B}a&oCd4AvNXpn;f<6jZWM4=!pc?xm@Z+Px?4b+gdu)Uao28_5Y#RFGF-zs`&N6uGf4cwwkAxw*2?8Jh z0w4eaAOHd&00JNY0w4eauR4K0$o>DJ*IxB*qkSL%0w4eaAOHd&00JNY0w4eaAOHd{ zL%_WLkM;k{5JXuJ009sH0T2KI5C8!X009sH0T6ig31I#I>gxzC1OX5L0T2KI5C8!X z009sH0T2Lzmn9H2E5QB#FN+c7K>!3m00ck)1V8`;KmY_l00cnbl_r4e|F5)q&^QnP z0T2KI5C8!X009sH0T2KI5cpCEL}ULGy3G9|bm{ZhKVABZ@n2o~X#Dob@OW(av*GRV zUxoL>7r0*xeHQvn=x6kgUN^rs!oNGgg=D45j`YWM*d$rY%)z(6{#*tl?u~!(NGVVqKoDa8H#e%?$RIj82wvtWpS%m*eS+W zHx% zETo$zspO`pvHVhbd;v4bx&!sE>gb^b@D zgVh|%(%2s9qZGH&nARuLCu|%y-OitEMGiaL7xghDAKOI0mfV7#IgNLlQOX-*bYQzMyo0M66B_9Ee*0YabML9zg1a@HC0sgfY<@k6jl!;RaDK@Y{@Fs ztV_n8sx<;qUG=ECct|}X2NjTdV)RXJ$$gXU#`MinnDIo_4n(P1)hM>HU@!Dz-CzxK z1j-ctXpoed7HEef(A3mwt6~I(o7F`pzac5DJ*i@}w65|@Ig~Y3X)uL+s;vO4+@4(T zG{uj*N`bd7l{;un{ffFF87(t|T(QnrVC?z@DypKe3VC(ki1+7P17Ea)!ax~~J#NXw zN|c(^?L&!_?0FB6BP=pOJ?QG>g2P zUA5DHGu(B|+Lud_Lpv5LYSlB;PNr-e2lVUZ)R<~z1bo~Xe3vysdxPMqeOF@iMPbUW zU9dJ|ha(aG-K$)C!t+7iaQBDKPGY3jDRnY^!Pa(Aovbd`_w2hDBK*~>+;iS?Rvo#{ zr{SR6>SUb3clig3FOCcdU$&W3zXQ{mR@0otRJ&3&c1V+6zuA(h(_D&XO5`!Tz}9ar zRO%CVTe?(ZVqI>?bZoD5AJaX98Fh%gge|3qNBEVigQh&bLkOMEBjs>}Pfv1f&a)~v zyoyz|A<2rnd!g5w+$=}Zwj(!@tjMmQl=$=>7vU!-xu1}$BM0J2xZ9zi^|;AF+i~Zw zC(US`V>H;1;a-*=P9 z!*?gy@yBN=W)vU=7_w5MX2`nbQ$>n0WL9P@>6+$wqL!~eg~NSu0bfHrL5d#6NiZ0ILdsY`Mr;H(2uQoXJ|8H9-DvYJFo zUJ7Jxfa?>qy!vc$ceYC|fdZmT9z@lG2k0HhiP@nMesOZpSaVHS6x>f?T4|4(tL^CM zZ$o_Oz0gHr{4d8Ijolgjx5%H5{&i$>xHR&g7ycuB`NFS<{xt+~JnlbV(rJvVN5>DlyIo2=V5})x75NlRMSQc~KGtKp6SqRkwFFJa zb^Z37#hVW_B>i!Vrn2U2)SF_{@_?5AJ{R0;R>8`thvrE;n)(`aV&NpOG|<3ocrP)v zxcSzL&O$oX3-~TQESn{NOX4Osz3GZUOVzt~vCvmR^n0($VLOs0c-V zD{OksMORxI>DjHdELUlbW|a!8gec+~YDM)H=1o;Iyh7%Vs7XrAtLJ0~Wc5uZl~{`D zK9SeydC=jwrHH$EHqSgQyHs2&(z)El!uDceshG&UIl|w)KF9?*_1eZD`1jsYz7gTS zN0ZSjLCZ0l*ct2Y@h1kDOicem=SMb!%e!?4V;TB;_7~?_(rpB_}zdvs#F) zf3k7iJjozd_F3h?RqV-HbjUaZ%W7ZnjBE!TY0H07)-CC1GBNwcizYQL717oi9q7%A ze`Zf=_FnG}7d`V2JK-C>I_r_NHRkxi;7Q#xU8wlvJ6EjO@}p~(a_tK3_pb#N>m2J0 
zYW1Hzy+s!#%oB!=e{iRh=6Q<$hm)sv+h4-9pZS_m|7$0IuuOzs3!WRVmjxi7l{2A}`I+vhJ^_6tPg1-26AA8_b0IO{or4u$5B4{h8MFz%5%ifp_!{8&lmB}(2=uM%DxL2*FOCV{yEeja zCI$|_02DJa_z(Ks_Qk6f!uI=?;n?eKcljOM`%VyHFw;bM(yHF|y91Z;iS|%D!oTwl z*Iu+*BP(<|s!o?N$W!Zqoom6i1oB*MdJ}zrE-fU?a|kVWg4i3p6S)81J%fQQ2!H?x zfB*=900@8p2!H?xfB*=bUjlgk-}%)odIka@00JNY0w4eaAOHd&00JNY0xkhu|944X z3j!bj0w4eaAOHd&00JNY0w4ea=a&H1|L0e?=otur00@8p2!H?xfB*=900@8p2)G1r z|G!HDTMz&N5C8!X009sH0T2KI5C8!XIKKpN{r~*x7Ci$25C8!X009sH0T2KI5C8!X z00EZ(*8eUEY(W47KmY_l00ck)1V8`;KmY_l;QSK6{r~4zx9AxNfB*=900@8p2!H?x zfB*=900_7QaQ)vUfh`Dt00@8p2!H?xfB*=900@8p2%KL6SpT12-J)k800JNY0w4ea zAOHd&00JNY0wCZL!2SO&32Z?C1V8`;KmY_l00ck)1V8`;K;Zlm!1e$0t6TI81V8`; zKmY_l00ck)1V8`;KmY_>0$BgMB(Mbm5C8!X009sH0T2KI5C8!X0D<#Mz`XyTi+vHI zKfFKy1V8`;KmY_l00ck)1V8`;KmY{JAAv9z4n^_(|MRC<^aun%00ck)1V8`;KmY_l z00ck)1iA@e{ohRl`yc=UAOHd&00JNY0w4eaAOHd&aQ+D3`~T-pr|1y~fB*=900@8p z2!H?xfB*=900?vw!1}+N2=+k$1V8`;KmY_l00ck)1V8`;K;Zll!2SQ{Pp9Y+2!H?x WfB*=900@8p2!H?xfB*<|6Zl`7K~6CM diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql new file mode 100644 index 00000000000..69194d62391 --- /dev/null +++ b/management/server/testdata/storev1.sql @@ -0,0 +1,39 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT 
"api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); +CREATE TABLE `routes` (`id` text,`account_id` text,`network` text,`domains` text,`keep_route` numeric,`net_id` text,`description` text,`peer` text,`peer_groups` text,`network_type` integer,`masquerade` numeric,`metric` integer,`enabled` numeric,`groups` text,`access_control_groups` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_routes_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `name_server_groups` (`id` text,`account_id` text,`name` text,`description` text,`name_servers` text,`groups` text,`primary` numeric,`domains` text,`enabled` numeric,`search_domains_enabled` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_name_server_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY KEY (`id`)); +CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); +CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); +CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); +CREATE INDEX `idx_peers_key` ON `peers`(`key`); +CREATE INDEX `idx_peers_account_id` ON `peers`(`account_id`); +CREATE INDEX `idx_users_account_id` ON `users`(`account_id`); +CREATE INDEX `idx_personal_access_tokens_user_id` ON `personal_access_tokens`(`user_id`); +CREATE INDEX `idx_groups_account_id` ON `groups`(`account_id`); +CREATE INDEX `idx_policies_account_id` ON `policies`(`account_id`); +CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); +CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); +CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); +CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); + +INSERT INTO accounts VALUES('auth0|61bf82ddeab084006aa1bccd','','2024-10-02 
17:00:54.181873+02:00','','',0,'a443c07a-5765-4a78-97fc-390d9c1d0e49','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO accounts VALUES('google-oauth2|103201118415301331038','','2024-10-02 17:00:54.225803+02:00','','',0,'b6d0b152-364e-40c1-a8a1-fa7bcac2267f','{"IP":"100.64.0.0","Mask":"/8AAAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO setup_keys VALUES('831727121','auth0|61bf82ddeab084006aa1bccd','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','One-off key','one-off','2021-12-24 16:09:45.926075752+01:00','2022-01-23 16:09:45.926075752+01:00','2021-12-24 16:09:45.926075752+01:00',0,1,'2021-12-24 16:12:45.763424077+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('1769568301','auth0|61bf82ddeab084006aa1bccd','EB51E9EB-A11F-4F6E-8E49-C982891B405A','Default key','reusable','2021-12-24 16:09:45.926073628+01:00','2022-01-23 16:09:45.926073628+01:00','2021-12-24 16:09:45.926073628+01:00',0,1,'2021-12-24 16:13:06.236748538+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('2485964613','google-oauth2|103201118415301331038','5AFB60DB-61F2-4251-8E11-494847EE88E9','Default key','reusable','2021-12-24 16:10:02.238476+01:00','2022-01-23 16:10:02.238476+01:00','2021-12-24 16:10:02.238476+01:00',0,1,'2021-12-24 16:12:05.994307717+01:00','[]',0,0); +INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038','A72E4DC2-00DE-4542-8A24-62945438104E','One-off key','one-off','2021-12-24 16:10:02.238478209+01:00','2022-01-23 16:10:02.238478209+01:00','2021-12-24 16:10:02.238478209+01:00',0,1,'2021-12-24 16:11:27.015741738+01:00','[]',0,0); +INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 
17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO installations VALUES(1,''); +
diff --git a/management/server/testdata/storev1.sqlite b/management/server/testdata/storev1.sqlite deleted file mode 100644 index 9a376698e4d226fc08fa68c12fb9bb4cf50375cd..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 163840
[... base85 GIT binary patch data omitted ...]
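Note: keeping the fixture as plain SQL (storev1.sql) instead of a binary .sqlite file makes the test data reviewable in diffs. Below is a minimal sketch of how such a dump could be loaded for a test; the loadSQLFixture helper is hypothetical and only assumes database/sql plus the github.com/mattn/go-sqlite3 driver that is already listed in go.mod, not any API from this patch.

package server_test

import (
	"database/sql"
	"os"
	"testing"

	_ "github.com/mattn/go-sqlite3" // SQLite driver, already a dependency of this repo
)

// loadSQLFixture is a hypothetical helper: it executes a plain-SQL dump such as
// management/server/testdata/storev1.sql against a fresh in-memory SQLite
// database, which is the kind of use a text fixture enables.
func loadSQLFixture(t *testing.T, path string) *sql.DB {
	t.Helper()

	dump, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("read fixture %s: %v", path, err)
	}

	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("open in-memory sqlite: %v", err)
	}

	// The dump is a flat list of CREATE TABLE / CREATE INDEX / INSERT statements,
	// so it can be executed as a single script.
	if _, err := db.Exec(string(dump)); err != nil {
		t.Fatalf("apply fixture %s: %v", path, err)
	}
	return db
}

A test would call loadSQLFixture(t, "management/server/testdata/storev1.sql") and query the returned *sql.DB directly.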
Date: Thu, 10 Oct 2024 14:14:56 +0200 Subject: [PATCH 36/81] Add billing user role (#2714) --- management/server/user.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/management/server/user.go b/management/server/user.go index 38a8ac0c401..71608ef20e1 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -19,10 +19,11 @@ import ( ) const ( - UserRoleOwner UserRole = "owner" - UserRoleAdmin UserRole = "admin" - UserRoleUser UserRole = "user" - UserRoleUnknown UserRole = "unknown" + UserRoleOwner UserRole = "owner" + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" + UserRoleUnknown UserRole = "unknown" + UserRoleBillingAdmin UserRole = "billing_admin" UserStatusActive UserStatus = "active" UserStatusDisabled UserStatus = "disabled" @@ -41,6 +42,8 @@ func StrRoleToUserRole(strRole string) UserRole { return UserRoleAdmin case "user": return UserRoleUser + case "billing_admin": + return UserRoleBillingAdmin default: return UserRoleUnknown }
From 09bdd271f10fa80f42424ffdb14deb1db60e55a9 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:54:34 +0200 Subject: [PATCH 37/81] [client] Improve route acl (#2705) - Update nftables library to v0.2.0 - Mark traffic that was originally destined for local and applies the input rules in the forward chain if said traffic was redirected (e.g.
by Docker) - Add nft rules to internal map only if flush was successful - Improve error message if handle is 0 (= not found or hasn't been refreshed) - Add debug logging when route rules are added - Replace nftables userdata (rule ID) with a rule hash --- client/firewall/iptables/acl_linux.go | 57 +++++++- client/firewall/iptables/manager_linux.go | 2 +- client/firewall/iptables/router_linux.go | 15 ++- client/firewall/manager/firewall.go | 6 +- client/firewall/nftables/acl_linux.go | 124 +++++++++++++++--- client/firewall/nftables/router_linux.go | 35 +++-- client/firewall/nftables/router_linux_test.go | 8 +- client/internal/acl/id/id.go | 41 +++++- go.mod | 18 +-- go.sum | 32 ++--- util/net/net.go | 3 +- 11 files changed, 267 insertions(+), 74 deletions(-) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index c6a96a876cd..c271e592dce 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -21,13 +22,19 @@ const ( chainNameOutputRules = "NETBIRD-ACL-OUTPUT" ) +type entry struct { + spec []string + position int +} + type aclManager struct { iptablesClient *iptables.IPTables wgIface iFaceMapper routingFwChainName string - entries map[string][][]string - ipsetStore *ipsetStore + entries map[string][][]string + optionalEntries map[string][]entry + ipsetStore *ipsetStore } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { @@ -36,8 +43,9 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi wgIface: wgIface, routingFwChainName: routingFwChainName, - entries: make(map[string][][]string), - ipsetStore: newIpsetStore(), + entries: make(map[string][][]string), + optionalEntries: make(map[string][]entry), + ipsetStore: newIpsetStore(), } err := ipset.Init() @@ -46,6 +54,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi } m.seedInitialEntries() + m.seedInitialOptionalEntries() err = m.cleanChains() if err != nil { @@ -232,6 +241,19 @@ func (m *aclManager) cleanChains() error { } } + ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") + if err != nil { + return fmt.Errorf("list chains: %w", err) + } + if ok { + for _, rule := range m.entries["PREROUTING"] { + err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...) 
+ if err != nil { + log.Errorf("failed to delete rule: %v, %s", rule, err) + } + } + } + for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := ipset.Flush(ipsetName); err != nil { log.Errorf("flush ipset %q during reset: %v", ipsetName, err) @@ -267,6 +289,17 @@ func (m *aclManager) createDefaultChains() error { } } + for chainName, entries := range m.optionalEntries { + for _, entry := range entries { + if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil { + log.Errorf("failed to insert optional entry %v: %v", entry.spec, err) + continue + } + m.entries[chainName] = append(m.entries[chainName], entry.spec) + } + } + clear(m.optionalEntries) + return nil } @@ -295,6 +328,22 @@ func (m *aclManager) seedInitialEntries() { m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) } +func (m *aclManager) seedInitialOptionalEntries() { + m.optionalEntries["FORWARD"] = []entry{ + { + spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules}, + position: 2, + }, + } + + m.optionalEntries["PREROUTING"] = []entry{ + { + spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)}, + position: 1, + }, + } +} + func (m *aclManager) appendToEntries(chainName string, spec []string) { m.entries[chainName] = append(m.entries[chainName], spec) } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 6fefd58e67e..94bd2fccfe1 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -78,7 +78,7 @@ func (m *Manager) AddPeerFiltering( } func (m *Manager) AddRouteFiltering( - sources [] netip.Prefix, + sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 737b207854b..e60c352d5c1 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -305,10 +305,7 @@ func (r *router) cleanUpDefaultForwardRules() error { log.Debug("flushing routing related tables") for _, chain := range []string{chainRTFWD, chainRTNAT} { - table := tableFilter - if chain == chainRTNAT { - table = tableNat - } + table := r.getTableForChain(chain) ok, err := r.iptablesClient.ChainExists(table, chain) if err != nil { @@ -329,15 +326,19 @@ func (r *router) cleanUpDefaultForwardRules() error { func (r *router) createContainers() error { for _, chain := range []string{chainRTFWD, chainRTNAT} { if err := r.createAndSetupChain(chain); err != nil { - return fmt.Errorf("create chain %s: %v", chain, err) + return fmt.Errorf("create chain %s: %w", chain, err) } } if err := r.insertEstablishedRule(chainRTFWD); err != nil { - return fmt.Errorf("insert established rule: %v", err) + return fmt.Errorf("insert established rule: %w", err) + } + + if err := r.addJumpRules(); err != nil { + return fmt.Errorf("add jump rules: %w", err) } - return r.addJumpRules() + return nil } func (r *router) createAndSetupChain(chain string) error { diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index a6185d3708e..556bda0d6b1 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -132,7 +132,7 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error { // 
GenerateSetName generates a unique name for an ipset based on the given sources. func GenerateSetName(sources []netip.Prefix) string { // sort for consistent naming - sortPrefixes(sources) + SortPrefixes(sources) var sourcesStr strings.Builder for _, src := range sources { @@ -170,9 +170,9 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { return merged } -// sortPrefixes sorts the given slice of netip.Prefix in place. +// SortPrefixes sorts the given slice of netip.Prefix in place. // It sorts first by IP address, then by prefix length (most specific to least specific). -func sortPrefixes(prefixes []netip.Prefix) { +func SortPrefixes(prefixes []netip.Prefix) { sort.Slice(prefixes, func(i, j int) bool { addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) if addrCmp != 0 { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index eaf7fb6a023..61434f03518 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -11,12 +11,14 @@ import ( "time" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" + nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -29,6 +31,7 @@ const ( chainNameInputFilter = "netbird-acl-input-filter" chainNameOutputFilter = "netbird-acl-output-filter" chainNameForwardFilter = "netbird-acl-forward-filter" + chainNamePrerouting = "netbird-rt-prerouting" allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) @@ -40,15 +43,14 @@ var ( ) type AclManager struct { - rConn *nftables.Conn - sConn *nftables.Conn - wgIface iFaceMapper - routeingFwChainName string + rConn *nftables.Conn + sConn *nftables.Conn + wgIface iFaceMapper + routingFwChainName string workTable *nftables.Table chainInputRules *nftables.Chain chainOutputRules *nftables.Chain - chainFwFilter *nftables.Chain ipsetStore *ipsetStore rules map[string]*Rule @@ -61,7 +63,7 @@ type iFaceMapper interface { IsUserspaceBind() bool } -func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) { +func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) // and is permanent. 
Using same connection for both type of operations @@ -72,11 +74,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa } m := &AclManager{ - rConn: &nftables.Conn{}, - sConn: sConn, - wgIface: wgIface, - workTable: table, - routeingFwChainName: routeingFwChainName, + rConn: &nftables.Conn{}, + sConn: sConn, + wgIface: wgIface, + workTable: table, + routingFwChainName: routingFwChainName, ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), @@ -462,9 +464,9 @@ func (m *AclManager) createDefaultChains() (err error) { } // netbird-acl-forward-filter - m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) - m.addJumpRulesToRtForward() // to netbird-rt-fwd - m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME) + chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) + m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd + m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME) err = m.rConn.Flush() if err != nil { @@ -472,10 +474,96 @@ func (m *AclManager) createDefaultChains() (err error) { return fmt.Errorf(flushError, err) } + if err := m.allowRedirectedTraffic(chainFwFilter); err != nil { + log.Errorf("failed to allow redirected traffic: %s", err) + } + return nil } -func (m *AclManager) addJumpRulesToRtForward() { +// Makes redirected traffic originally destined for the host itself (now subject to the forward filter) +// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the +// netbird peer IP. +func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { + preroutingChain := m.rConn.AddChain(&nftables.Chain{ + Name: chainNamePrerouting, + Table: m.workTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityMangle, + }) + + m.addPreroutingRule(preroutingChain) + + m.addFwmarkToForward(chainFwFilter) + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) { + m.rConn.AddRule(&nftables.Rule{ + Table: m.workTable, + Chain: preroutingChain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(m.wgIface.Name()), + }, + &expr.Fib{ + Register: 1, + ResultADDRTYPE: true, + FlagDADDR: true, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL), + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + SourceRegister: true, + }, + }, + }) +} + +func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { + m.rConn.InsertRule(&nftables.Rule{ + Table: m.workTable, + Chain: chainFwFilter, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyMARK, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), + }, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: m.chainInputRules.Name, + }, + }, + }) +} + +func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) { expressions := []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ @@ -485,13 +573,13 @@ func (m *AclManager) addJumpRulesToRtForward() { }, &expr.Verdict{ Kind: 
expr.VerdictJump, - Chain: m.routeingFwChainName, + Chain: m.routingFwChainName, }, } _ = m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, - Chain: m.chainFwFilter, + Chain: chainFwFilter, Exprs: expressions, }) } @@ -509,7 +597,7 @@ func (m *AclManager) createChain(name string) *nftables.Chain { return chain } -func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain { +func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain { polAccept := nftables.ChainPolicyAccept chain := &nftables.Chain{ Name: name, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index aa61e18585f..9b8fdbda53d 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -10,6 +10,7 @@ import ( "net/netip" "strings" + "github.com/davecgh/go-spew/spew" "github.com/google/nftables" "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" @@ -24,7 +25,7 @@ import ( const ( chainNameRoutingFw = "netbird-rt-fwd" - chainNameRoutingNat = "netbird-rt-nat" + chainNameRoutingNat = "netbird-rt-postrouting" chainNameForward = "FORWARD" userDataAcceptForwardRuleIif = "frwacceptiif" @@ -149,7 +150,6 @@ func (r *router) loadFilterTable() (*nftables.Table, error) { } func (r *router) createContainers() error { - r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingFw, Table: r.workTable, @@ -157,25 +157,26 @@ func (r *router) createContainers() error { insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + prio := *nftables.ChainPriorityNATSource - 1 + r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ Name: chainNameRoutingNat, Table: r.workTable, Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, + Priority: &prio, Type: nftables.ChainTypeNAT, }) r.acceptForwardRules() - err := r.refreshRulesMap() - if err != nil { + if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } - err = r.conn.Flush() - if err != nil { + if err := r.conn.Flush(); err != nil { return fmt.Errorf("nftables: unable to initialize table: %v", err) } + return nil } @@ -188,6 +189,7 @@ func (r *router) AddRouteFiltering( dPort *firewall.Port, action firewall.Action, ) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) if _, ok := r.rules[string(ruleKey)]; ok { return ruleKey, nil @@ -248,9 +250,18 @@ func (r *router) AddRouteFiltering( UserData: []byte(ruleKey), } - r.rules[string(ruleKey)] = r.conn.AddRule(rule) + rule = r.conn.AddRule(rule) - return ruleKey, r.conn.Flush() + log.Tracef("Adding route rule %s", spew.Sdump(rule)) + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf(flushError, err) + } + + r.rules[string(ruleKey)] = rule + + log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) + + return ruleKey, nil } func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { @@ -288,6 +299,10 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { return nil } + if nftRule.Handle == 0 { + return fmt.Errorf("route rule %s has no handle", ruleKey) + } + setName := r.findSetNameInRule(nftRule) if err := r.deleteNftRule(nftRule, ruleKey); err != nil { @@ -658,7 +673,7 @@ func (r 
*router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) } - log.Debugf("nftables: removed rules for %s", pair.Destination) + log.Debugf("nftables: removed nat rules for %s", pair.Destination) return nil } diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index bbf92f3beaf..25b7587ac67 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -314,6 +314,10 @@ func TestRouter_AddRouteFiltering(t *testing.T) { ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) require.NoError(t, err, "AddRouteFiltering failed") + t.Cleanup(func() { + require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule") + }) + // Check if the rule is in the internal map rule, ok := r.rules[ruleKey.GetRuleID()] assert.True(t, ok, "Rule not found in internal map") @@ -346,10 +350,6 @@ func TestRouter_AddRouteFiltering(t *testing.T) { // Verify actual nftables rule content verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) - - // Clean up - err = r.DeleteRouteRule(ruleKey) - require.NoError(t, err, "Failed to delete rule") }) } } diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go index e27fce439fc..8ce73655d5f 100644 --- a/client/internal/acl/id/id.go +++ b/client/internal/acl/id/id.go @@ -1,8 +1,11 @@ package id import ( + "crypto/sha256" + "encoding/hex" "fmt" "net/netip" + "strconv" "github.com/netbirdio/netbird/client/firewall/manager" ) @@ -21,5 +24,41 @@ func GenerateRouteRuleKey( dPort *manager.Port, action manager.Action, ) RuleID { - return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action)) + manager.SortPrefixes(sources) + + h := sha256.New() + + // Write all fields to the hasher, with delimiters + h.Write([]byte("sources:")) + for _, src := range sources { + h.Write([]byte(src.String())) + h.Write([]byte(",")) + } + + h.Write([]byte("destination:")) + h.Write([]byte(destination.String())) + + h.Write([]byte("proto:")) + h.Write([]byte(proto)) + + h.Write([]byte("sPort:")) + if sPort != nil { + h.Write([]byte(sPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("dPort:")) + if dPort != nil { + h.Write([]byte(dPort.String())) + } else { + h.Write([]byte("")) + } + + h.Write([]byte("action:")) + h.Write([]byte(strconv.Itoa(int(action)))) + hash := hex.EncodeToString(h.Sum(nil)) + + // prepend destination prefix to be able to identify the rule + return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16])) } diff --git a/go.mod b/go.mod index e7137ce5bf5..cb37ca4bb6c 100644 --- a/go.mod +++ b/go.mod @@ -19,8 +19,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.24.0 - golang.org/x/sys v0.21.0 + golang.org/x/crypto v0.28.0 + golang.org/x/sys v0.26.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -38,6 +38,7 @@ require ( github.com/cilium/ebpf v0.15.0 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 + github.com/davecgh/go-spew v1.1.1 github.com/eko/gocache/v3 v3.1.1 github.com/fsnotify/fsnotify v1.7.0 github.com/gliderlabs/ssh v0.3.4 @@ 
-45,7 +46,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/gopacket v1.1.19 - github.com/google/nftables v0.0.0-20220808154552-2eca00135732 + github.com/google/nftables v0.2.0 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 @@ -55,7 +56,7 @@ require ( github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.7 github.com/mattn/go-sqlite3 v1.14.19 - github.com/mdlayher/socket v0.4.1 + github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 @@ -89,10 +90,10 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/net v0.26.0 + golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.19.0 - golang.org/x/sync v0.7.0 - golang.org/x/term v0.21.0 + golang.org/x/sync v0.8.0 + golang.org/x/term v0.25.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.7 @@ -133,7 +134,6 @@ require ( github.com/containerd/containerd v1.7.16 // indirect github.com/containerd/log v0.1.0 // indirect github.com/cpuguy83/dockercfg v0.3.1 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect @@ -219,7 +219,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/text v0.19.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index 4563dc9335f..05df5c66ed5 100644 --- a/go.sum +++ b/go.sum @@ -322,8 +322,8 @@ github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/martian/v3 v3.1.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= -github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= +github.com/google/nftables v0.2.0 h1:PbJwaBmbVLzpeldoeUKGkE2RjstrjPKMl6oLrfEJ6/8= +github.com/google/nftables v0.2.0/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20191218002539-d4f498aebedc/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= @@ -475,8 +475,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= 
-github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= -github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= @@ -774,8 +774,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -871,8 +871,8 @@ golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -901,8 +901,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -974,8 +974,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -983,8 +983,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= +golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= +golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -999,8 +999,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/util/net/net.go b/util/net/net.go index 61b47dbe7d3..035d7552bc7 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -11,7 +11,8 @@ import ( const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard - NetbirdFwmark = 0x1BD00 + NetbirdFwmark = 0x1BD00 + PreroutingFwmark = 0x1BD01 envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" ) From b2379175fe856e24c71263b7f0dcfac77ab8a722 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 10 Oct 2024 16:23:46 +0200 Subject: [PATCH 38/81] [signal] new signal dispatcher version (#2722) --- go.mod 
| 2 +- go.sum | 4 ++-- signal/server/signal.go | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index cb37ca4bb6c..e7e3c17a68a 100644 --- a/go.mod +++ b/go.mod @@ -61,7 +61,7 @@ require ( github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd - github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 05df5c66ed5..e9bc318d6fd 100644 --- a/go.sum +++ b/go.sum @@ -525,8 +525,8 @@ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811- github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg= -github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= diff --git a/signal/server/signal.go b/signal/server/signal.go index 63cc43bd7ef..305fd052b2e 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -6,6 +6,7 @@ import ( "io" "time" + "github.com/netbirdio/signal-dispatcher/dispatcher" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -13,8 +14,6 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "github.com/netbirdio/signal-dispatcher/dispatcher" - "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" "github.com/netbirdio/netbird/signal/proto" From 0e95f16cdd8462242a6912ef061347caba67041b Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 11 Oct 2024 16:24:30 +0200 Subject: [PATCH 39/81] [relay,client] Relay/fix/wg roaming (#2691) If a peer connection switches from Relayed to ICE P2P, the Relayed proxy still consumes the data the other peer sends. Because the proxy is operating, the WireGuard switches back to the Relayed proxy automatically, thanks to the roaming feature. Extend the Proxy implementation with pause/resume functions. Before switching to the p2p connection, pause the WireGuard proxy operation to prevent unnecessary package sources. Consider waiting some milliseconds after the pause to be sure the WireGuard engine already processed all UDP msg in from the pipe. 
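Sketched below is the switch-over sequence this description implies, as a minimal, self-contained Go example. The Pause/Work method names come from this patch; the Proxy interface shown here is a simplified stand-in for the one in client/internal/wgproxy, and switchToDirect is a hypothetical helper rather than the actual Conn code.

package main

import (
	"fmt"
	"net"
)

// Proxy mirrors only the subset of the wgproxy proxy behaviour used below.
type Proxy interface {
	EndpointAddr() *net.UDPAddr
	Work()  // start or resume forwarding between the remote conn and WireGuard
	Pause() // stop forwarding so WireGuard roaming cannot flip back to the relay
}

// switchToDirect shows the ordering from the commit message: pause the relayed
// proxy first, then point WireGuard at the direct (P2P) endpoint; on failure,
// resume the relayed proxy so connectivity is not lost.
func switchToDirect(relay Proxy, configureWG func(*net.UDPAddr) error, direct *net.UDPAddr) error {
	if relay != nil {
		relay.Pause()
	}
	if err := configureWG(direct); err != nil {
		if relay != nil {
			relay.Work()
		}
		return fmt.Errorf("update wireguard endpoint: %w", err)
	}
	return nil
}

func main() {
	direct := &net.UDPAddr{IP: net.ParseIP("192.0.2.10"), Port: 51820}
	err := switchToDirect(nil, func(addr *net.UDPAddr) error {
		fmt.Println("would set WireGuard endpoint to", addr)
		return nil
	}, direct)
	fmt.Println("switch result:", err)
}

The important detail is the ordering: Pause() on the relayed proxy comes first, so the roaming logic cannot observe relayed packets and flip the endpoint back while the direct endpoint is being configured.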
--- client/internal/peer/conn.go | 225 +++++++++++++----------- client/internal/wgproxy/ebpf/proxy.go | 35 +--- client/internal/wgproxy/ebpf/wrapper.go | 102 +++++++++-- client/internal/wgproxy/proxy.go | 5 +- client/internal/wgproxy/proxy_test.go | 2 +- client/internal/wgproxy/usp/proxy.go | 93 +++++++--- 6 files changed, 296 insertions(+), 166 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 0d4ad2396b3..1b740388d95 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -82,8 +82,6 @@ type Conn struct { config ConnConfig statusRecorder *Status wgProxyFactory *wgproxy.Factory - wgProxyICE wgproxy.Proxy - wgProxyRelay wgproxy.Proxy signaler *Signaler iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager @@ -106,7 +104,8 @@ type Conn struct { beforeAddPeerHooks []nbnet.AddHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc - endpointRelay *net.UDPAddr + wgProxyICE wgproxy.Proxy + wgProxyRelay wgproxy.Proxy // for reconnection operations iCEDisconnected chan bool @@ -257,8 +256,7 @@ func (conn *Conn) Close() { conn.wgProxyICE = nil } - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } @@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon conn.log.Debugf("ICE connection is ready") - conn.statusICE.Set(StatusConnected) - - defer conn.updateIceState(iceConnInfo) - if conn.currentConnPriority > priority { + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) return } conn.log.Infof("set ICE to active connection") - endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo) - if err != nil { - return + var ( + ep *net.UDPAddr + wgProxy wgproxy.Proxy + err error + ) + if iceConnInfo.RelayedOnLocal { + wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn) + if err != nil { + conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) + return + } + ep = wgProxy.EndpointAddr() + conn.wgProxyICE = wgProxy + } else { + directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String()) + if err != nil { + log.Errorf("failed to resolveUDPaddr") + conn.handleConfigurationFailure(err, nil) + return + } + ep = directEp } - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) - - conn.connIDICE = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) } conn.workerRelay.DisableWgWatcher() - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { - if wgProxy != nil { - if err := wgProxy.CloseConn(); err != nil { - conn.log.Warnf("Failed to close turn connection: %v", err) - } - } - conn.log.Warnf("Failed to update wg peer configuration: %v", err) - return + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Pause() } - wgConfigWorkaround() - if conn.wgProxyICE != nil { - if err := conn.wgProxyICE.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } + if wgProxy != nil { + wgProxy.Work() } - conn.wgProxyICE = wgProxy + if err = 
conn.configureWGEndpoint(ep); err != nil { + conn.handleConfigurationFailure(err, wgProxy) + return + } + wgConfigWorkaround() conn.currentConnPriority = priority - + conn.statusICE.Set(StatusConnected) + conn.updateIceState(iceConnInfo) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) } @@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { conn.log.Tracef("ICE connection state changed to %s", newState) + if conn.wgProxyICE != nil { + if err := conn.wgProxyICE.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + // switch back to relay connection - if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { + if conn.isReadyToUpgrade() { conn.log.Debugf("ICE disconnected, set Relay to active connection") - err := conn.configureWGEndpoint(conn.endpointRelay) - if err != nil { + conn.wgProxyRelay.Work() + + if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } conn.workerRelay.EnableWgWatcher(conn.ctx) @@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { changed := conn.statusICE.Get() != newState && newState != StatusConnecting conn.statusICE.Set(newState) - select { - case conn.iCEDisconnected <- changed: - default: - } + conn.notifyReconnectLoopICEDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { if conn.ctx.Err() != nil { if err := rci.relayedConn.Close(); err != nil { - log.Warnf("failed to close unnecessary relayed connection: %v", err) + conn.log.Warnf("failed to close unnecessary relayed connection: %v", err) } return } - conn.log.Debugf("Relay connection is ready to use") - conn.statusRelay.Set(StatusConnected) + conn.log.Debugf("Relay connection has been established, setup the WireGuard") - wgProxy := conn.wgProxyFactory.GetProxy() - endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn) + wgProxy, err := conn.newProxy(rci.relayedConn) if err != nil { conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) return } - conn.log.Infof("created new wgProxy for relay connection: %s", endpoint) - - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - conn.endpointRelay = endpointUdpAddr - conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) - defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) - if conn.currentConnPriority > connPriorityRelay { - if conn.statusICE.Get() == StatusConnected { - log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) - return - } + if conn.iceP2PIsActive() { + conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) + conn.wgProxyRelay = wgProxy + conn.statusRelay.Set(StatusConnected) + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + return } - conn.connIDRelay = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil { - conn.log.Errorf("Before add peer hook failed: %v", err) - } + if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { + conn.log.Errorf("Before add 
peer hook failed: %v", err) } - err = conn.configureWGEndpoint(endpointUdpAddr) - if err != nil { + wgProxy.Work() + if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.log.Warnf("Failed to close relay connection: %v", err) } - conn.log.Errorf("Failed to update wg peer configuration: %v", err) + conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err) return } conn.workerRelay.EnableWgWatcher(conn.ctx) - wgConfigWorkaround() - if conn.wgProxyRelay != nil { - if err := conn.wgProxyRelay.CloseConn(); err != nil { - conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) - } - } - conn.wgProxyRelay = wgProxy + wgConfigWorkaround() conn.currentConnPriority = connPriorityRelay - + conn.statusRelay.Set(StatusConnected) + conn.wgProxyRelay = wgProxy + conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.log.Infof("start to communicate with peer via relay") conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) } @@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { return } - log.Debugf("relay connection is disconnected") + conn.log.Debugf("relay connection is disconnected") if conn.currentConnPriority == connPriorityRelay { - log.Debugf("clean up WireGuard config") - err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - if err != nil { + conn.log.Debugf("clean up WireGuard config") + if err := conn.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } } if conn.wgProxyRelay != nil { - conn.endpointRelay = nil _ = conn.wgProxyRelay.CloseConn() conn.wgProxyRelay = nil } changed := conn.statusRelay.Get() != StatusDisconnected conn.statusRelay.Set(StatusDisconnected) - - select { - case conn.relayDisconnected <- changed: - default: - } + conn.notifyReconnectLoopRelayDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { Relayed: conn.isRelayed(), ConnStatusUpdate: time.Now(), } - - err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState) - if err != nil { + if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil { conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) } } @@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool { return true } +func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error { + conn.connIDICE = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connIDICE, ip); err != nil { + return err + } + } + return nil +} + func (conn *Conn) freeUpConnID() { if conn.connIDRelay != "" { for _, hook := range conn.afterRemovePeerHooks { @@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() { } } -func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) { - if !iceConnInfo.RelayedOnLocal { - return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil - } - conn.log.Debugf("setup ice turn connection") +func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { + conn.log.Debugf("setup proxied WireGuard connection") wgProxy := conn.wgProxyFactory.GetProxy() - ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) - if err != nil { + if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil { conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) - if errClose := 
wgProxy.CloseConn(); errClose != nil { - conn.log.Warnf("failed to close turn proxy connection: %v", errClose) + return nil, err + } + return wgProxy, nil +} + +func (conn *Conn) isReadyToUpgrade() bool { + return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay +} + +func (conn *Conn) iceP2PIsActive() bool { + return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected +} + +func (conn *Conn) removeWgPeer() error { + return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) +} + +func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) { + select { + case conn.relayDisconnected <- changed: + default: + } +} + +func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) { + select { + case conn.iCEDisconnected <- changed: + default: + } +} + +func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { + conn.log.Warnf("Failed to update wg peer configuration: %v", err) + if wgProxy != nil { + if ierr := wgProxy.CloseConn(); ierr != nil { + conn.log.Warnf("Failed to close wg proxy: %v", ierr) } - return nil, nil, err } - return ep, wgProxy, nil + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.Work() + } } func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/internal/wgproxy/ebpf/proxy.go index 27ede3ef1d0..e850f4533ce 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -5,7 +5,6 @@ package ebpf import ( "context" "fmt" - "io" "net" "os" "sync" @@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error { } // AddTurnConn add new turn connection for the proxy -func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { +func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) { wgEndpointPort, err := p.storeTurnConn(turnConn) if err != nil { return nil, err } - go p.proxyToLocal(ctx, wgEndpointPort, turnConn) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) wgEndpoint := &net.UDPAddr{ @@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error { return nberrors.FormatErrorOrNil(result) } -func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) { - defer p.removeTurnConn(endpointPort) - - var ( - err error - n int - ) - buf := make([]byte, 1500) - for ctx.Err() == nil { - n, err = remoteConn.Read(buf) - if err != nil { - if ctx.Err() != nil { - return - } - if err != io.EOF { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) - } - return - } - - if err := p.sendPkg(buf[:n], endpointPort); err != nil { - if ctx.Err() != nil || p.ctx.Err() != nil { - return - } - log.Errorf("failed to write out turn pkg to local conn: %v", err) - } - } -} - // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. 
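The conn.go changes above assume a small lifecycle on the proxy side: AddTurnConn only prepares the proxy, Work starts (or resumes) forwarding, Pause drops inbound packets while a better path is active, and CloseConn tears it down. Below is a minimal, self-contained sketch of how a caller drives that interface; the fakeProxy type, the port and the addresses are illustrative stand-ins, not code from the patch.

```go
package main

import (
	"context"
	"fmt"
	"net"
)

// Proxy mirrors the interface shape introduced by this patch series.
type Proxy interface {
	AddTurnConn(ctx context.Context, turnConn net.Conn) error
	EndpointAddr() *net.UDPAddr
	Work()
	Pause()
	CloseConn() error
}

// fakeProxy is a hypothetical no-op implementation used only for illustration.
type fakeProxy struct{ endpoint *net.UDPAddr }

func (f *fakeProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) error {
	f.endpoint = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 51820} // assumed local WG port
	return nil
}
func (f *fakeProxy) EndpointAddr() *net.UDPAddr { return f.endpoint }
func (f *fakeProxy) Work()                      {}
func (f *fakeProxy) Pause()                     {}
func (f *fakeProxy) CloseConn() error           { return nil }

func main() {
	relayed, other := net.Pipe() // stand-in for the relayed/TURN connection
	defer other.Close()

	var p Proxy = &fakeProxy{}
	if err := p.AddTurnConn(context.Background(), relayed); err != nil { // prepare only, no traffic yet
		panic(err)
	}
	p.Work() // start forwarding; the same call resumes after a Pause
	fmt.Println("configure WG peer endpoint:", p.EndpointAddr())
	p.Pause()         // e.g. a direct ICE path took over
	_ = p.CloseConn() // e.g. the peer disconnected
}
```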
func (p *WGEBPFProxy) proxyToRemote() { @@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { return packetConn, nil } -func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { +func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { localhost := net.ParseIP("127.0.0.1") payload := gopacket.Payload(data) diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go index c5639f840cc..b6a8ac45228 100644 --- a/client/internal/wgproxy/ebpf/wrapper.go +++ b/client/internal/wgproxy/ebpf/wrapper.go @@ -4,8 +4,13 @@ package ebpf import ( "context" + "errors" "fmt" + "io" "net" + "sync" + + log "github.com/sirupsen/logrus" ) // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call @@ -13,20 +18,55 @@ type ProxyWrapper struct { WgeBPFProxy *WGEBPFProxy remoteConn net.Conn - cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread -} + ctx context.Context + cancel context.CancelFunc -func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - ctxConn, cancel := context.WithCancel(ctx) - addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn) + wgEndpointAddr *net.UDPAddr + + pausedMu sync.Mutex + paused bool + isStarted bool +} +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { + addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { - cancel() - return nil, fmt.Errorf("add turn conn: %w", err) + return fmt.Errorf("add turn conn: %w", err) + } + p.remoteConn = remoteConn + p.ctx, p.cancel = context.WithCancel(ctx) + p.wgEndpointAddr = addr + return err +} + +func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { + return p.wgEndpointAddr +} + +func (p *ProxyWrapper) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToLocal(p.ctx) } - e.remoteConn = remoteConn - e.cancel = cancel - return addr, err +} + +func (p *ProxyWrapper) Pause() { + if p.remoteConn == nil { + return + } + + log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map @@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error { } return nil } + +func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { + defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + + buf := make([]byte, 1500) + for { + n, err := p.readFromRemote(ctx, buf) + if err != nil { + return + } + + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + + err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) + p.pausedMu.Unlock() + + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to write out turn pkg to local conn: %v", err) + } + } +} + +func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) { + n, err := p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return 0, ctx.Err() + } + if !errors.Is(err, io.EOF) { + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + } + return 0, err + } + return n, nil +} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go index 96fae8dd103..558121cdd5a 100644 --- a/client/internal/wgproxy/proxy.go +++ b/client/internal/wgproxy/proxy.go @@ -7,6 +7,9 @@ 
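The key behaviour in ProxyWrapper above is that a paused proxy keeps reading from the relayed connection but discards the payload, so the relay socket never backs up while the direct path is in use. A stripped-down sketch of that pattern follows; pausableReader and its forward callback are illustrative names standing in for the wrapper and for WGEBPFProxy.sendPkg.

```go
package main

import (
	"net"
	"sync"
)

// pausableReader drains a remote connection and forwards packets only while unpaused.
type pausableReader struct {
	mu      sync.Mutex
	paused  bool
	remote  net.Conn
	forward func(b []byte) error // stand-in for the real "write towards WireGuard" step
}

func (r *pausableReader) Pause()  { r.mu.Lock(); r.paused = true; r.mu.Unlock() }
func (r *pausableReader) Resume() { r.mu.Lock(); r.paused = false; r.mu.Unlock() }

func (r *pausableReader) loop() {
	buf := make([]byte, 1500)
	for {
		n, err := r.remote.Read(buf)
		if err != nil {
			return // closed connection or EOF; the real code also honours context cancellation
		}

		r.mu.Lock()
		paused := r.paused
		r.mu.Unlock()
		if paused {
			continue // keep draining, but drop the packet while paused
		}

		if err := r.forward(buf[:n]); err != nil {
			return
		}
	}
}

func main() {
	local, remote := net.Pipe()
	r := &pausableReader{remote: remote, forward: func([]byte) error { return nil }}
	go r.loop()

	r.Pause()
	_, _ = local.Write([]byte("dropped while paused"))
	r.Resume()
	_, _ = local.Write([]byte("forwarded"))
	_ = local.Close() // ends the loop
}
```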
import ( // Proxy is a transfer layer between the relayed connection and the WireGuard type Proxy interface { - AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) + AddTurnConn(ctx context.Context, turnConn net.Conn) error + EndpointAddr() *net.UDPAddr + Work() + Pause() CloseConn() error } diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go index b09e6be555f..b88ff3f83c1 100644 --- a/client/internal/wgproxy/proxy_test.go +++ b/client/internal/wgproxy/proxy_test.go @@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { relayedConn := newMockConn() - _, err := tt.proxy.AddTurnConn(ctx, relayedConn) + err := tt.proxy.AddTurnConn(ctx, relayedConn) if err != nil { t.Errorf("error: %v", err) } diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go index 83a8725d899..f73500717a9 100644 --- a/client/internal/wgproxy/usp/proxy.go +++ b/client/internal/wgproxy/usp/proxy.go @@ -15,13 +15,17 @@ import ( // WGUserSpaceProxy proxies type WGUserSpaceProxy struct { localWGListenPort int - ctx context.Context - cancel context.CancelFunc remoteConn net.Conn localConn net.Conn + ctx context.Context + cancel context.CancelFunc closeMu sync.Mutex closed bool + + pausedMu sync.Mutex + paused bool + isStarted bool } // NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation @@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { return p } -// AddTurnConn start the proxy with the given remote conn -func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { - p.ctx, p.cancel = context.WithCancel(ctx) - - p.remoteConn = remoteConn - - var err error +// AddTurnConn +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. 
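The AddTurnConn/EndpointAddr pair that follows dials the local WireGuard listen port over loopback UDP and then reports that socket's local address, which is what gets configured as the WireGuard peer endpoint. Below is a standalone sketch of that derivation; the throwaway listener stands in for the WireGuard socket, and the explicit 127.0.0.1 address is an assumption (the real code dials ":port").

```go
package main

import (
	"fmt"
	"net"
)

func main() {
	// Throwaway UDP listener standing in for the local WireGuard socket.
	wg, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
	if err != nil {
		panic(err)
	}
	defer wg.Close()
	wgPort := wg.LocalAddr().(*net.UDPAddr).Port

	// The proxy dials the WireGuard listen port over loopback.
	dialer := net.Dialer{}
	conn, err := dialer.Dial("udp", fmt.Sprintf("127.0.0.1:%d", wgPort))
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// EndpointAddr: resolve the dialed socket's local address into a *net.UDPAddr.
	endpoint, err := net.ResolveUDPAddr(conn.LocalAddr().Network(), conn.LocalAddr().String())
	if err != nil {
		panic(err)
	}
	fmt.Println("WG peer endpoint to configure:", endpoint)
}
```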
+func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { dialer := net.Dialer{} - p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) + localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { log.Errorf("failed dialing to local Wireguard port %s", err) - return nil, err + return err } - go p.proxyToRemote() - go p.proxyToLocal() + p.ctx, p.cancel = context.WithCancel(ctx) + p.localConn = localConn + p.remoteConn = remoteConn - return p.localConn.LocalAddr(), err + return err +} + +func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { + if p.localConn == nil { + return nil + } + endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String()) + return endpointUdpAddr +} + +// Work starts the proxy or resumes it if it was paused +func (p *WGUserSpaceProxy) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + if !p.isStarted { + p.isStarted = true + go p.proxyToRemote(p.ctx) + go p.proxyToLocal(p.ctx) + } +} + +// Pause pauses the proxy from receiving data from the remote peer +func (p *WGUserSpaceProxy) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() } // CloseConn close the localConn @@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error { } // proxyToRemote proxies from Wireguard to the RemoteKey -func (p *WGUserSpaceProxy) proxyToRemote() { +func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to remote loop: %s", err) @@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for ctx.Err() == nil { n, err := p.localConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to read from wg interface conn: %s", err) @@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() { _, err = p.remoteConn.Write(buf[:n]) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } @@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() { } // proxyToLocal proxies from the Remote peer to local WireGuard -func (p *WGUserSpaceProxy) proxyToLocal() { +// if the proxy is paused it will drain the remote conn and drop the packets +func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to local loop: %s", err) @@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() { }() buf := make([]byte, 1500) - for p.ctx.Err() == nil { + for { n, err := p.remoteConn.Read(buf) if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + _, err = p.localConn.Write(buf[:n]) + p.pausedMu.Unlock() + if err != nil { - if p.ctx.Err() != nil { + if ctx.Err() != nil { return } log.Debugf("failed to write to wg interface conn: %s", err) From da3a053e2bed950bf9cf382f0690435548221745 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 12 Oct 2024 08:35:51 +0200 Subject: [PATCH 40/81] [management] Refactor getAccountIDWithAuthorizationClaims (#2715) This change restructures the getAccountIDWithAuthorizationClaims method to improve 
readability, maintainability, and performance. - have dedicated methods to handle possible cases - introduced Store.UpdateAccountDomainAttributes and Store.GetAccountUsers methods - Remove GetAccount and SaveAccount dependency - added tests --- management/server/account.go | 337 +++++++++++++++++----------- management/server/account_test.go | 278 +++++++++++------------ management/server/sql_store.go | 37 +++ management/server/sql_store_test.go | 60 +++++ management/server/store.go | 2 + 5 files changed, 441 insertions(+), 273 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 6ee0015f86f..c468b5eccaf 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" b64 "encoding/base64" + "errors" "fmt" "hash/crc32" "math/rand" @@ -50,6 +51,8 @@ const ( CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -1285,7 +1288,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { return "", err } return account.Id, nil @@ -1300,28 +1303,39 @@ func isNil(i idp.Manager) bool { } // addAccountIDToIDPAppMeta update user's app metadata in idp manager -func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error { +func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { + accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + if err != nil { + return err + } + cachedAccount := &Account{ + Id: accountID, + Users: make(map[string]*User), + } + for _, user := range accountUsers { + cachedAccount.Users[user.Id] = user + } // user can be nil if it wasn't found (e.g., just created) - user, err := am.lookupUserInCache(ctx, userID, account) + user, err := am.lookupUserInCache(ctx, userID, cachedAccount) if err != nil { return err } - if user != nil && user.AppMetadata.WTAccountID == account.Id { + if user != nil && user.AppMetadata.WTAccountID == accountID { // it was already set, so we skip the unnecessary update log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s", - account.Id, userID) + accountID, userID) return nil } - err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id}) + err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) if err != nil { return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err) } // refresh cache to reflect the update - _, err = am.refreshCache(ctx, account.Id) + _, err = am.refreshCache(ctx, accountID) if err != nil { return err } @@ -1545,48 +1559,69 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun return am.cacheManager.Set(am.ctx, accountID, data, cacheStore.WithExpiration(cacheEntryExpiration())) } -// updateAccountDomainAttributes updates the account domain attributes and 
then, saves the account -func (am *DefaultAccountManager) updateAccountDomainAttributes(ctx context.Context, account *Account, claims jwtclaims.AuthorizationClaims, +// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then saves the account changes +func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims, primaryDomain bool, ) error { + if claims.Domain == "" { + log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + return nil + } - if claims.Domain != "" { - account.IsDomainPrimaryAccount = primaryDomain + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlockAccount() - lowerDomain := strings.ToLower(claims.Domain) - userObj := account.Users[claims.UserId] - if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { - account.Domain = lowerDomain - } - // prevent updating category for different domain until admin logs in - if account.Domain == lowerDomain { - account.DomainCategory = claims.DomainCategory - } - } else { - log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) + return err } - err := am.Store.SaveAccount(ctx, account) + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return nil + } + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { + log.WithContext(ctx).Errorf("error getting user: %v", err) return err } - return nil + + newDomain := accountDomain + newCategory := domainCategory + + lowerDomain := strings.ToLower(claims.Domain) + if accountDomain != lowerDomain && user.HasAdminPower() { + newDomain = lowerDomain + } + + if accountDomain == lowerDomain { + newCategory = claims.DomainCategory + } + + return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategory, primaryDomain) } // handleExistingUserAccount handles existing User accounts and update its domain attributes. +// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, +// we compare the account's ID with the domain account ID, and if they don't match, we set the account as +// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain +// was previously unclassified or classified as public, so the N users that logged in at that time have their own accounts +// and peers that shouldn't be lost. 
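The rewritten updateAccountDomainAttributesIfNotUpToDate above boils down to two independent decisions: only a user with admin power may move the account to the domain carried in the claims, and the domain category is refreshed only once the stored domain already matches the claimed one. A simplified, standalone sketch of that decision; the helper name and parameters are illustrative, not the patch's API.

```go
package main

import (
	"fmt"
	"strings"
)

// nextDomainAttributes computes the domain/category that would be written back,
// mirroring the patch's logic as a pure function.
func nextDomainAttributes(accountDomain, accountCategory, claimsDomain, claimsCategory string, userIsAdmin bool) (string, string) {
	newDomain, newCategory := accountDomain, accountCategory

	lowerDomain := strings.ToLower(claimsDomain)
	if accountDomain != lowerDomain && userIsAdmin {
		newDomain = lowerDomain // only admins may re-home the account to the claimed domain
	}
	if accountDomain == lowerDomain {
		newCategory = claimsCategory // category follows the claims only once the domains match
	}
	return newDomain, newCategory
}

func main() {
	// Admin logs in with a differently-cased private domain: the domain is lowered and adopted,
	// but the category stays unchanged until a later login where the domains already match.
	d, c := nextDomainAttributes("old.example.com", "unknown", "Corp.Example.COM", "private", true)
	fmt.Println(d, c) // corp.example.com unknown

	d, c = nextDomainAttributes("corp.example.com", "unknown", "corp.example.com", "private", false)
	fmt.Println(d, c) // corp.example.com private
}
```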
func (am *DefaultAccountManager) handleExistingUserAccount( ctx context.Context, - existingAcc *Account, - primaryDomain bool, + userAccountID string, + domainAccountID string, claims jwtclaims.AuthorizationClaims, ) error { - err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain) + primaryDomain := domainAccountID == "" || userAccountID == domainAccountID + err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain) if err != nil { return err } // we should register the account ID to this user's metadata in our IDP manager - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID) if err != nil { return err } @@ -1594,44 +1629,58 @@ func (am *DefaultAccountManager) handleExistingUserAccount( return nil } -// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, +// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account, // otherwise it will create a new account and make it primary account for the domain. -func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { if claims.UserId == "" { - return nil, fmt.Errorf("user ID is empty") + return "", fmt.Errorf("user ID is empty") } - var ( - account *Account - err error - ) + lowerDomain := strings.ToLower(claims.Domain) - // if domain already has a primary account, add regular user - if domainAcc != nil { - account = domainAcc - account.Users[claims.UserId] = NewRegularUser(claims.UserId) - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } else { - account, err = am.newAccount(ctx, claims.UserId, lowerDomain) - if err != nil { - return nil, err - } - err = am.updateAccountDomainAttributes(ctx, account, claims, true) - if err != nil { - return nil, err - } + + newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) + if err != nil { + return "", err } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account) + newAccount.Domain = lowerDomain + newAccount.DomainCategory = claims.DomainCategory + newAccount.IsDomainPrimaryAccount = true + + err = am.Store.SaveAccount(ctx, newAccount) if err != nil { - return nil, err + return "", err + } + + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) + if err != nil { + return "", err + } + + am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) + + return newAccount.Id, nil +} + +func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) + defer unlockAccount() + + usersMap := make(map[string]*User) + usersMap[claims.UserId] = NewRegularUser(claims.UserId) + err := am.Store.SaveUsers(domainAccountID, usersMap) + if err != nil { + return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.UserJoined, nil) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID) + if err != nil { + return "", err + } - return account, nil + am.StoreEvent(ctx, 
claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil) + + return domainAccountID, nil } // redeemInvite checks whether user has been invited and redeems the invite @@ -1775,7 +1824,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s // GetAccountIDFromToken returns an account ID associated with this token. func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { - return "", "", fmt.Errorf("user ID is empty") + return "", "", errors.New(emptyUserID) } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -1961,16 +2010,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } // getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. +// if domain is not private or domain is invalid, it will return the account ID by user ID. // if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // // Use cases: // -// New user + New account + New domain -> create account, user role = admin (if private domain, index domain) +// New user + New account + New domain -> create account, user role = owner (if private domain, index domain) // -// New user + New account + Existing Private Domain -> add user to the existing account, user role = regular (not admin) +// New user + New account + Existing Private Domain -> add user to the existing account, user role = user (not admin) // -// New user + New account + Existing Public Domain -> create account, user role = admin +// New user + New account + Existing Public Domain -> create account, user role = owner // // Existing user + Existing account + Existing Domain -> Nothing changes (if private, index domain) // @@ -1980,98 +2030,123 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. 
User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) + if claims.UserId == "" { - return "", fmt.Errorf("user ID is empty") + return "", errors.New(emptyUserID) } - // if Account ID is part of the claims - // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - if claims.AccountId != "" { - exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, claims.AccountId) - if err != nil { - return "", err - } - if !exists { - return "", status.Errorf(status.NotFound, "account %s does not exist", claims.AccountId) - } - return claims.AccountId, nil - } return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) - } else if claims.AccountId != "" { - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err != nil { - return "", err - } + } - if userAccountID != claims.AccountId { - return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) - } + if claims.AccountId != "" { + return am.handlePrivateAccountWithIDFromClaim(ctx, claims) + } - domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) - if err != nil { + // We checked if the domain has a primary account already + domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain) + if cancel != nil { + defer cancel() + } + if err != nil { + return "", err + } + + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) + return "", err + } + + if userAccountID != "" { + if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil { return "", err } - if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain { - return userAccountID, nil - } + return userAccountID, nil } - start := time.Now() - unlock := am.Store.AcquireGlobalLock(ctx) - defer unlock() - log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) + if domainAccountID != "" { + return am.addNewUserToDomainAccount(ctx, domainAccountID, claims) + } + + return am.addNewPrivateAccount(ctx, domainAccountID, claims) +} +func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + if domainAccountID != "" { + return domainAccountID, nil, nil + } + + log.WithContext(ctx).Debugf("no primary account found for domain %s, acquiring global lock", domain) + cancel := am.Store.AcquireGlobalLock(ctx) + + // check again if the domain has a primary account because of simultaneous requests + domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", nil, err + } + + return domainAccountID, cancel, nil +} + +func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { + userAccountID, 
err := am.Store.GetAccountIDByUserID(claims.UserId) + if err != nil { + log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) + return "", err + } + + if userAccountID != claims.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + } + + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) + return "", err + } + + if domainIsUpToDate(accountDomain, domainCategory, claims) { + return claims.AccountId, nil + } // We checked if the domain has a primary account already domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + if handleNotFound(err) != nil { + log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) + return "", err + } + + err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims) if err != nil { - // if NotFound we are good to continue, otherwise return error - e, ok := status.FromError(err) - if !ok || e.Type() != status.NotFound { - return "", err - } + return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) - if err == nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID) - defer unlockAccount() - account, err := am.Store.GetAccountByUser(ctx, claims.UserId) - if err != nil { - return "", err - } - // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, - // we compare the account's ID with the domain account ID, and if they don't match, we set the account as - // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain - // was previously unclassified or classified as public so N users that logged int that time, has they own account - // and peers that shouldn't be lost. 
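getPrivateDomainWithGlobalLock above follows a check/lock/re-check pattern: look up the primary account for the domain without the global lock first, and only take the lock (and look again) when no primary account exists yet, so two concurrent logins for a brand-new domain cannot both create one. A minimal sketch of the pattern, with stand-ins for the store lookup and the global lock:

```go
package main

import (
	"fmt"
	"sync"
)

// getPrimaryDomainAccount mirrors the check/lock/re-check flow. lookup stands in for
// the store's private-domain query (assumed to handle its own synchronization) and
// globalMu for the store's global lock.
func getPrimaryDomainAccount(domain string, lookup func(string) (string, bool), globalMu *sync.Mutex) (string, func()) {
	if id, ok := lookup(domain); ok {
		return id, nil // a primary account already exists; no global lock needed
	}

	globalMu.Lock() // serialize first-account creation for this domain
	if id, ok := lookup(domain); ok {
		globalMu.Unlock()
		return id, nil // another request won the race while we waited for the lock
	}
	return "", globalMu.Unlock // caller creates the account, then calls the returned unlock
}

func main() {
	var mu sync.Mutex
	accounts := map[string]string{"existing.com": "acc-1"}
	lookup := func(d string) (string, bool) { id, ok := accounts[d]; return id, ok }

	id, _ := getPrimaryDomainAccount("existing.com", lookup, &mu)
	fmt.Println("found:", id)

	_, unlock := getPrimaryDomainAccount("new.com", lookup, &mu)
	accounts["new.com"] = "acc-2" // create the primary account while holding the lock
	unlock()
}
```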
- primaryDomain := domainAccountID == "" || account.Id == domainAccountID - if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil { - return "", err - } + return claims.AccountId, nil +} - return account.Id, nil - } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - var domainAccount *Account - if domainAccountID != "" { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) - defer unlockAccount() - domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) - if err != nil { - return "", err - } - } +func handleNotFound(err error) error { + if err == nil { + return nil + } - account, err := am.handleNewUserAccount(ctx, domainAccount, claims) - if err != nil { - return "", err - } - return account.Id, nil - } else { - // other error - return "", err + e, ok := status.FromError(err) + if !ok || e.Type() != status.NotFound { + return err } + return nil +} + +func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { + return claims.Domain != "" && claims.Domain != domain && claims.DomainCategory == PrivateCategory && domainCategory != PrivateCategory } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index 4dd58e88e0d..b20071cba04 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -465,22 +465,6 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { type initUserParams jwtclaims.AuthorizationClaims - type test struct { - name string - inputClaims jwtclaims.AuthorizationClaims - inputInitUserParams initUserParams - inputUpdateAttrs bool - inputUpdateClaimAccount bool - testingFunc require.ComparisonAssertionFunc - expectedMSG string - expectedUserRole UserRole - expectedDomainCategory string - expectedDomain string - expectedPrimaryDomainStatus bool - expectedCreatedBy string - expectedUsers []string - } - var ( publicDomain = "public.com" privateDomain = "private.com" @@ -492,143 +476,153 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { UserId: "defaultUser", } - testCase1 := test{ - name: "New User With Public Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: publicDomain, - UserId: "pub-domain-user", - DomainCategory: PublicCategory, - }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomainCategory: "", - expectedDomain: publicDomain, - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pub-domain-user", - expectedUsers: []string{"pub-domain-user"}, - } - initUnknown := defaultInitAccount initUnknown.DomainCategory = UnknownCategory initUnknown.Domain = unknownDomain - testCase2 := test{ - name: "New User With Unknown Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: unknownDomain, - UserId: "unknown-domain-user", - DomainCategory: UnknownCategory, - }, - inputInitUserParams: initUnknown, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: unknownDomain, - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: 
"unknown-domain-user", - expectedUsers: []string{"unknown-domain-user"}, - } - - testCase3 := test{ - name: "New User With Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, - }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, - } - privateInitAccount := defaultInitAccount privateInitAccount.Domain = privateDomain privateInitAccount.DomainCategory = PrivateCategory - testCase4 := test{ - name: "New Regular User With Existing Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: privateDomain, - UserId: "new-pvt-domain-user", - DomainCategory: PrivateCategory, + testCases := []struct { + name string + inputClaims jwtclaims.AuthorizationClaims + inputInitUserParams initUserParams + inputUpdateAttrs bool + inputUpdateClaimAccount bool + testingFunc require.ComparisonAssertionFunc + expectedMSG string + expectedUserRole UserRole + expectedDomainCategory string + expectedDomain string + expectedPrimaryDomainStatus bool + expectedCreatedBy string + expectedUsers []string + }{ + { + name: "New User With Public Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: publicDomain, + UserId: "pub-domain-user", + DomainCategory: PublicCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomainCategory: "", + expectedDomain: publicDomain, + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pub-domain-user", + expectedUsers: []string{"pub-domain-user"}, + }, + { + name: "New User With Unknown Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: unknownDomain, + UserId: "unknown-domain-user", + DomainCategory: UnknownCategory, + }, + inputInitUserParams: initUnknown, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: unknownDomain, + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "unknown-domain-user", + expectedUsers: []string{"unknown-domain-user"}, + }, + { + name: "New User With Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputUpdateAttrs: true, - inputInitUserParams: privateInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleUser, - expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, - } - - testCase5 := test{ - name: "Existing User With Existing Reclassified Private Domain", - inputClaims: 
jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "New Regular User With Existing Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: privateDomain, + UserId: "new-pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputUpdateAttrs: true, + inputInitUserParams: privateInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleUser, + expectedDomain: privateDomain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, + }, + { + name: "Existing User With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase6 := test{ - name: "Existing Account Id With Existing Reclassified Private Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: defaultInitAccount.Domain, - UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + { + name: "Existing Account Id With Existing Reclassified Private Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: defaultInitAccount.Domain, + UserId: defaultInitAccount.UserId, + DomainCategory: PrivateCategory, + }, + inputUpdateClaimAccount: true, + inputInitUserParams: defaultInitAccount, + testingFunc: require.Equal, + expectedMSG: "account IDs should match", + expectedUserRole: UserRoleOwner, + expectedDomain: defaultInitAccount.Domain, + expectedDomainCategory: PrivateCategory, + expectedPrimaryDomainStatus: true, + expectedCreatedBy: defaultInitAccount.UserId, + expectedUsers: []string{defaultInitAccount.UserId}, }, - inputUpdateClaimAccount: true, - inputInitUserParams: defaultInitAccount, - testingFunc: require.Equal, - expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, - expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, - expectedPrimaryDomainStatus: true, - expectedCreatedBy: defaultInitAccount.UserId, - expectedUsers: []string{defaultInitAccount.UserId}, - } - - testCase7 := test{ - name: "User With Private Category And Empty Domain", - inputClaims: jwtclaims.AuthorizationClaims{ - Domain: "", - UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + { + name: "User With Private Category And Empty Domain", + inputClaims: jwtclaims.AuthorizationClaims{ + Domain: "", + UserId: "pvt-domain-user", + DomainCategory: PrivateCategory, + }, + inputInitUserParams: defaultInitAccount, + testingFunc: require.NotEqual, + 
expectedMSG: "account IDs shouldn't match", + expectedUserRole: UserRoleOwner, + expectedDomain: "", + expectedDomainCategory: "", + expectedPrimaryDomainStatus: false, + expectedCreatedBy: "pvt-domain-user", + expectedUsers: []string{"pvt-domain-user"}, }, - inputInitUserParams: defaultInitAccount, - testingFunc: require.NotEqual, - expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, - expectedDomain: "", - expectedDomainCategory: "", - expectedPrimaryDomainStatus: false, - expectedCreatedBy: "pvt-domain-user", - expectedUsers: []string{"pvt-domain-user"}, - } - - for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} { + } + + for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -640,7 +634,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { require.NoError(t, err, "get init account failed") if testCase.inputUpdateAttrs { - err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) + err = manager.updateAccountDomainAttributesIfNotUpToDate(context.Background(), initAccount.Id, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 615203bee38..de3dfa9455e 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -323,6 +323,29 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. return nil } +func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { + accountCopy := Account{ + Domain: domain, + DomainCategory: category, + IsDomainPrimaryAccount: isPrimaryDomain, + } + + fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} + result := s.db.WithContext(ctx).Model(&Account{}). + Select(fieldsToUpdate). + Where(idQueryCondition, accountID). 
+ Updates(&accountCopy) + if result.Error != nil { + return result.Error + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "account %s", accountID) + } + + return nil +} + func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -518,6 +541,20 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } +func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { + var users []*User + result := s.db.Find(&users, accountIDCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") + } + log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting users from store") + } + + return users, nil +} + func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { var groups []*nbgroup.Group result := s.db.Find(&groups, accountIDCondition, accountID) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 06e118fd22b..20e812ea709 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1191,3 +1191,63 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { }) assert.NoError(t, err) } + +func TestSqlite_GetAccoundUsers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + users, err := store.GetAccountUsers(context.Background(), accountID) + require.NoError(t, err) + require.Len(t, users, len(account.Users)) +} + +func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + t.Run("Should update attributes with public domain", func(t *testing.T) { + require.NoError(t, err) + domain := "example.com" + category := "public" + IsDomainPrimaryAccount := false + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + t.Run("Should update attributes with private domain", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + require.NoError(t, err) + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.Equal(t, IsDomainPrimaryAccount, account.IsDomainPrimaryAccount) + }) + + 
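UpdateAccountDomainAttributes above is a partial update: Select limits the UPDATE to the three domain columns (so zero values like a false primary flag are still written), and a zero RowsAffected is turned into a not-found error instead of silently succeeding. A hedged, self-contained sketch of the same GORM pattern against an in-memory SQLite database; the Account model here is a cut-down stand-in, not the real one, and the error text differs from the real status errors.

```go
package main

import (
	"errors"
	"fmt"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// Account is a trimmed-down stand-in for the real model.
type Account struct {
	Id                     string `gorm:"primaryKey"`
	Domain                 string
	DomainCategory         string
	IsDomainPrimaryAccount bool
}

func updateDomainAttributes(db *gorm.DB, accountID, domain, category string, primary bool) error {
	result := db.Model(&Account{}).
		Select("domain", "domain_category", "is_domain_primary_account"). // only these columns are written
		Where("id = ?", accountID).
		Updates(&Account{Domain: domain, DomainCategory: category, IsDomainPrimaryAccount: primary})
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		return errors.New("account not found") // the real code returns a status.NotFound error
	}
	return nil
}

func main() {
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&Account{}); err != nil {
		panic(err)
	}
	db.Create(&Account{Id: "acc-1", Domain: "old.com", DomainCategory: "unknown"})

	fmt.Println(updateDomainAttributes(db, "acc-1", "example.com", "private", true))   // <nil>
	fmt.Println(updateDomainAttributes(db, "missing", "example.com", "private", true)) // account not found
}
```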
t.Run("Should fail when account does not exist", func(t *testing.T) { + require.NoError(t, err) + domain := "test.com" + category := "private" + IsDomainPrimaryAccount := true + err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount) + require.Error(t, err) + }) + +} diff --git a/management/server/store.go b/management/server/store.go index d914bb8f7d5..131fd8aaab6 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -58,9 +58,11 @@ type Store interface { GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error + UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) + GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error From 3a88ac78ff80b77eecfdcf9d7d66663b017419aa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 12 Oct 2024 10:44:48 +0200 Subject: [PATCH 41/81] [client] Add table filter rules using iptables (#2727) This specifically concerns the established/related rule since this one is not compatible with iptables-nft even if it is generated the same way by iptables-translate. 
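The rules this patch inserts via iptables are the plain-text equivalent of the nftables expressions it previously generated: accept anything entering the WireGuard interface, and accept return traffic leaving it that belongs to an established or related conntrack flow. A hedged sketch using github.com/coreos/go-iptables follows; "wt0" is an assumed interface name and running it requires root privileges.

```go
package main

import (
	"log"

	"github.com/coreos/go-iptables/iptables"
)

func main() {
	const ifaceName = "wt0" // assumed WireGuard interface name

	ipt, err := iptables.New()
	if err != nil {
		log.Fatalf("iptables not available (the patch falls back to nftables here): %v", err)
	}

	rules := [][]string{
		{"-i", ifaceName, "-j", "ACCEPT"},
		{"-o", ifaceName, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
	}

	// Insert at position 1 so existing FORWARD rules (or a DROP policy) cannot shadow them.
	for _, rule := range rules {
		if err := ipt.Insert("filter", "FORWARD", 1, rule...); err != nil {
			log.Fatalf("insert rule %v: %v", rule, err)
		}
	}

	// Cleanup mirrors removeAcceptForwardRulesIptables in the patch.
	for _, rule := range rules {
		if err := ipt.DeleteIfExists("filter", "FORWARD", rule...); err != nil {
			log.Printf("remove rule %v: %v", rule, err)
		}
	}
}
```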
--- client/firewall/nftables/manager_linux.go | 47 +++-- .../firewall/nftables/manager_linux_test.go | 1 + client/firewall/nftables/router_linux.go | 186 ++++++++++++------ 3 files changed, 148 insertions(+), 86 deletions(-) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index d2258ae0869..01b08bd7111 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -315,28 +315,33 @@ func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain * rule := &nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Ct{ - Key: expr.CtKeySTATE, - Register: 1, - }, - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), - Xor: binaryutil.NativeEndian.PutUint32(0), - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - }, + Exprs: getEstablishedExprs(1), } conn.InsertRule(rule) } + +func getEstablishedExprs(register uint32) []expr.Any { + return []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: register, + }, + &expr.Bitwise{ + SourceRegister: register, + DestRegister: register, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: register, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } +} diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 904050a517f..bbe18ab0714 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -109,6 +109,7 @@ func TestNftablesManager(t *testing.T) { Register: 1, Data: []byte{0, 0, 0, 0}, }, + &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictAccept, }, diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 9b8fdbda53d..404ba695780 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -10,6 +10,7 @@ import ( "net/netip" "strings" + "github.com/coreos/go-iptables/iptables" "github.com/davecgh/go-spew/spew" "github.com/google/nftables" "github.com/google/nftables/binaryutil" @@ -81,7 +82,7 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa } } - err = r.cleanUpDefaultForwardRules() + err = r.removeAcceptForwardRules() if err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } @@ -98,40 +99,7 @@ func (r *router) Reset() error { // clear without deleting the ipsets, the nf table will be deleted by the caller r.ipsetCounter.Clear() - return r.cleanUpDefaultForwardRules() -} - -func (r *router) cleanUpDefaultForwardRules() error { - if r.filterTable == nil { - return nil - } - - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return fmt.Errorf("list chains: %v", err) - } - - for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { - continue - } - - rules, err := r.conn.GetRules(r.filterTable, chain) - if err != nil { - return fmt.Errorf("get rules: %v", err) - } - - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - 
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { - if err := r.conn.DelRule(rule); err != nil { - return fmt.Errorf("delete rule: %v", err) - } - } - } - } - - return r.conn.Flush() + return r.removeAcceptForwardRules() } func (r *router) loadFilterTable() (*nftables.Table, error) { @@ -167,7 +135,9 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeNAT, }) - r.acceptForwardRules() + if err := r.acceptForwardRules(); err != nil { + log.Errorf("failed to add accept rules for the forward chain: %s", err) + } if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) @@ -577,19 +547,60 @@ func (r *router) RemoveAllLegacyRouteRules() error { // that our traffic is not dropped by existing rules there. // The existing FORWARD rules/policies decide outbound traffic towards our interface. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. -func (r *router) acceptForwardRules() { +func (r *router) acceptForwardRules() error { if r.filterTable == nil { log.Debugf("table 'filter' not found for forward rules, skipping accept rules") - return + return nil + } + + fw := "iptables" + + defer func() { + log.Debugf("Used %s to add accept forward rules", fw) + }() + + // Try iptables first and fallback to nftables if iptables is not available + ipt, err := iptables.New() + if err != nil { + // filter table exists but iptables is not + log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) + + fw = "nftables" + return r.acceptForwardRulesNftables() } + return r.acceptForwardRulesIptables(ipt) +} + +func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { + var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { + if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) + } else { + log.Debugf("added iptables rule: %v", rule) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (r *router) getAcceptForwardRules() [][]string { + intf := r.wgIface.Name() + return [][]string{ + {"-i", intf, "-j", "ACCEPT"}, + {"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}, + } +} + +func (r *router) acceptForwardRulesNftables() error { intf := ifname(r.wgIface.Name()) // Rule for incoming interface (iif) with counter iifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ - Name: "FORWARD", + Name: chainNameForward, Table: r.filterTable, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookForward, @@ -609,6 +620,15 @@ func (r *router) acceptForwardRules() { } r.conn.InsertRule(iifRule) + oifExprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + } + // Rule for outgoing interface (oif) with counter oifRule := &nftables.Rule{ Table: r.filterTable, @@ -619,36 +639,72 @@ func (r *router) acceptForwardRules() { Hooknum: nftables.ChainHookForward, Priority: nftables.ChainPriorityFilter, }, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: intf, - }, - &expr.Ct{ - Key: expr.CtKeySTATE, - Register: 2, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | 
expr.CtStateBitRELATED), - Xor: binaryutil.NativeEndian.PutUint32(0), - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 2, - Data: []byte{0, 0, 0, 0}, - }, - &expr.Counter{}, - &expr.Verdict{Kind: expr.VerdictAccept}, - }, + Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } r.conn.InsertRule(oifRule) + + return nil +} + +func (r *router) removeAcceptForwardRules() error { + if r.filterTable == nil { + return nil + } + + // Try iptables first and fallback to nftables if iptables is not available + ipt, err := iptables.New() + if err != nil { + log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) + return r.removeAcceptForwardRulesNftables() + } + + return r.removeAcceptForwardRulesIptables(ipt) +} + +func (r *router) removeAcceptForwardRulesNftables() error { + chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return fmt.Errorf("list chains: %v", err) + } + + for _, chain := range chains { + if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + continue + } + + rules, err := r.conn.GetRules(r.filterTable, chain) + if err != nil { + return fmt.Errorf("get rules: %v", err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule: %v", err) + } + } + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { + var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { + if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { + merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) } // RemoveNatRule removes a nftables rule pair from nat chains From d93dd4fc7f47c9e1ac597af2570e6faaa36f1219 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 12 Oct 2024 18:21:34 +0200 Subject: [PATCH 42/81] [relay-server] Move the handshake logic to separated struct (#2648) * Move the handshake logic to separated struct - The server will response to the client after it ready to process the peer - Preload the response messages * Fix deprecated lint issue * Fix error handling * [relay-server] Relay measure auth time (#2675) Measure the Relay client's authentication time --- relay/metrics/realy.go | 52 +++++++++++-- relay/server/handshake.go | 153 ++++++++++++++++++++++++++++++++++++++ relay/server/relay.go | 127 ++++++------------------------- 3 files changed, 223 insertions(+), 109 deletions(-) create mode 100644 relay/server/handshake.go diff --git a/relay/metrics/realy.go b/relay/metrics/realy.go index 13799713a23..4dc98a0e009 100644 --- a/relay/metrics/realy.go +++ b/relay/metrics/realy.go @@ -16,8 +16,10 @@ const ( type Metrics struct { metric.Meter - TransferBytesSent metric.Int64Counter - TransferBytesRecv metric.Int64Counter + TransferBytesSent metric.Int64Counter + TransferBytesRecv metric.Int64Counter + AuthenticationTime metric.Float64Histogram + PeerStoreTime metric.Float64Histogram peers metric.Int64UpDownCounter peerActivityChan chan string @@ -52,11 +54,23 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) { return nil, err } + authTime, err := 
meter.Float64Histogram("relay_peer_authentication_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + + peerStoreTime, err := meter.Float64Histogram("relay_peer_store_time_milliseconds", metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...)) + if err != nil { + return nil, err + } + m := &Metrics{ - Meter: meter, - TransferBytesSent: bytesSent, - TransferBytesRecv: bytesRecv, - peers: peers, + Meter: meter, + TransferBytesSent: bytesSent, + TransferBytesRecv: bytesRecv, + AuthenticationTime: authTime, + PeerStoreTime: peerStoreTime, + peers: peers, ctx: ctx, peerActivityChan: make(chan string, 10), @@ -89,6 +103,16 @@ func (m *Metrics) PeerConnected(id string) { m.peerLastActive[id] = time.Time{} } +// RecordAuthenticationTime measures the time taken for peer authentication +func (m *Metrics) RecordAuthenticationTime(duration time.Duration) { + m.AuthenticationTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6) +} + +// RecordPeerStoreTime measures the time to store the peer in map +func (m *Metrics) RecordPeerStoreTime(duration time.Duration) { + m.PeerStoreTime.Record(m.ctx, float64(duration.Nanoseconds())/1e6) +} + // PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections func (m *Metrics) PeerDisconnected(id string) { m.peers.Add(m.ctx, -1) @@ -134,3 +158,19 @@ func (m *Metrics) readPeerActivity() { } } } + +func getStandardBucketBoundaries() []float64 { + return []float64{ + 0.1, + 0.5, + 1, + 5, + 10, + 50, + 100, + 500, + 1000, + 5000, + 10000, + } +} diff --git a/relay/server/handshake.go b/relay/server/handshake.go new file mode 100644 index 00000000000..0257300f82c --- /dev/null +++ b/relay/server/handshake.go @@ -0,0 +1,153 @@ +package server + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/auth" + "github.com/netbirdio/netbird/relay/messages" + //nolint:staticcheck + "github.com/netbirdio/netbird/relay/messages/address" + //nolint:staticcheck + authmsg "github.com/netbirdio/netbird/relay/messages/auth" +) + +// preparedMsg contains the marshalled success response messages +type preparedMsg struct { + responseHelloMsg []byte + responseAuthMsg []byte +} + +func newPreparedMsg(instanceURL string) (*preparedMsg, error) { + rhm, err := marshalResponseHelloMsg(instanceURL) + if err != nil { + return nil, err + } + + ram, err := messages.MarshalAuthResponse(instanceURL) + if err != nil { + return nil, fmt.Errorf("failed to marshal auth response msg: %w", err) + } + + return &preparedMsg{ + responseHelloMsg: rhm, + responseAuthMsg: ram, + }, nil +} + +func marshalResponseHelloMsg(instanceURL string) ([]byte, error) { + addr := &address.Address{URL: instanceURL} + addrData, err := addr.Marshal() + if err != nil { + return nil, fmt.Errorf("failed to marshal response address: %w", err) + } + + //nolint:staticcheck + responseMsg, err := messages.MarshalHelloResponse(addrData) + if err != nil { + return nil, fmt.Errorf("failed to marshal hello response: %w", err) + } + return responseMsg, nil +} + +type handshake struct { + conn net.Conn + validator auth.Validator + preparedMsg *preparedMsg + + handshakeMethodAuth bool + peerID string +} + +func (h *handshake) handshakeReceive() ([]byte, error) { + buf := make([]byte, messages.MaxHandshakeSize) + n, err := h.conn.Read(buf) + if err != nil { + return nil, fmt.Errorf("read from %s: %w", h.conn.RemoteAddr(), err) + } + + _, err = 
messages.ValidateVersion(buf[:n]) + if err != nil { + return nil, fmt.Errorf("validate version from %s: %w", h.conn.RemoteAddr(), err) + } + + msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) + if err != nil { + return nil, fmt.Errorf("determine message type from %s: %w", h.conn.RemoteAddr(), err) + } + + var ( + bytePeerID []byte + peerID string + ) + switch msgType { + //nolint:staticcheck + case messages.MsgTypeHello: + bytePeerID, peerID, err = h.handleHelloMsg(buf[messages.SizeOfProtoHeader:n]) + case messages.MsgTypeAuth: + h.handshakeMethodAuth = true + bytePeerID, peerID, err = h.handleAuthMsg(buf[messages.SizeOfProtoHeader:n]) + default: + return nil, fmt.Errorf("invalid message type %d from %s", msgType, h.conn.RemoteAddr()) + } + if err != nil { + return nil, err + } + h.peerID = peerID + return bytePeerID, nil +} + +func (h *handshake) handshakeResponse() error { + var responseMsg []byte + if h.handshakeMethodAuth { + responseMsg = h.preparedMsg.responseAuthMsg + } else { + responseMsg = h.preparedMsg.responseHelloMsg + } + + if _, err := h.conn.Write(responseMsg); err != nil { + return fmt.Errorf("handshake response write to %s (%s): %w", h.peerID, h.conn.RemoteAddr(), err) + } + + return nil +} + +func (h *handshake) handleHelloMsg(buf []byte) ([]byte, string, error) { + //nolint:staticcheck + rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) + if err != nil { + return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, h.conn.RemoteAddr()) + + authMsg, err := authmsg.UnmarshalMsg(authData) + if err != nil { + return nil, "", fmt.Errorf("unmarshal auth message: %w", err) + } + + //nolint:staticcheck + if err := h.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { + return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + } + + return rawPeerID, peerID, nil +} + +func (h *handshake) handleAuthMsg(buf []byte) ([]byte, string, error) { + rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) + if err != nil { + return nil, "", fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + + if err := h.validator.Validate(authPayload); err != nil { + return nil, "", fmt.Errorf("validate %s (%s): %w", peerID, h.conn.RemoteAddr(), err) + } + + return rawPeerID, peerID, nil +} diff --git a/relay/server/relay.go b/relay/server/relay.go index 76c01a697fd..6cd8506ae96 100644 --- a/relay/server/relay.go +++ b/relay/server/relay.go @@ -7,16 +7,13 @@ import ( "net/url" "strings" "sync" + "time" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/metric" "github.com/netbirdio/netbird/relay/auth" - "github.com/netbirdio/netbird/relay/messages" //nolint:staticcheck - "github.com/netbirdio/netbird/relay/messages/address" - //nolint:staticcheck - authmsg "github.com/netbirdio/netbird/relay/messages/auth" "github.com/netbirdio/netbird/relay/metrics" ) @@ -28,6 +25,7 @@ type Relay struct { store *Store instanceURL string + preparedMsg *preparedMsg closed bool closeMu sync.RWMutex @@ -69,6 +67,12 @@ func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, valida return nil, fmt.Errorf("get instance URL: %v", err) } + r.preparedMsg, err = newPreparedMsg(r.instanceURL) + if err != nil { + metricsCancel() + return nil, fmt.Errorf("prepare message: %v", err) + } + return r, nil } @@ -100,17 +104,22 @@ func 
getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { // Accept start to handle a new peer connection func (r *Relay) Accept(conn net.Conn) { + acceptTime := time.Now() r.closeMu.RLock() defer r.closeMu.RUnlock() if r.closed { return } - peerID, err := r.handshake(conn) + h := handshake{ + conn: conn, + validator: r.validator, + preparedMsg: r.preparedMsg, + } + peerID, err := h.handshakeReceive() if err != nil { log.Errorf("failed to handshake: %s", err) - cErr := conn.Close() - if cErr != nil { + if cErr := conn.Close(); cErr != nil { log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) } return @@ -118,7 +127,9 @@ func (r *Relay) Accept(conn net.Conn) { peer := NewPeer(r.metrics, peerID, conn, r.store) peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) + storeTime := time.Now() r.store.AddPeer(peer) + r.metrics.RecordPeerStoreTime(time.Since(storeTime)) r.metrics.PeerConnected(peer.String()) go func() { peer.Work() @@ -126,6 +137,12 @@ func (r *Relay) Accept(conn net.Conn) { peer.log.Debugf("relay connection closed") r.metrics.PeerDisconnected(peer.String()) }() + + if err := h.handshakeResponse(); err != nil { + log.Errorf("failed to send handshake response, close peer: %s", err) + peer.Close() + } + r.metrics.RecordAuthenticationTime(time.Since(acceptTime)) } // Shutdown closes the relay server @@ -151,99 +168,3 @@ func (r *Relay) Shutdown(ctx context.Context) { func (r *Relay) InstanceURL() string { return r.instanceURL } - -func (r *Relay) handshake(conn net.Conn) ([]byte, error) { - buf := make([]byte, messages.MaxHandshakeSize) - n, err := conn.Read(buf) - if err != nil { - return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err) - } - - _, err = messages.ValidateVersion(buf[:n]) - if err != nil { - return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err) - } - - msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) - if err != nil { - return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) - } - - var ( - responseMsg []byte - peerID []byte - ) - switch msgType { - //nolint:staticcheck - case messages.MsgTypeHello: - peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) - case messages.MsgTypeAuth: - peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) - default: - return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr()) - } - if err != nil { - return nil, err - } - - _, err = conn.Write(responseMsg) - if err != nil { - return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) - } - - return peerID, nil -} - -func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) { - //nolint:staticcheck - rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) - } - - peerID := messages.HashIDToString(rawPeerID) - log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr) - - authMsg, err := authmsg.UnmarshalMsg(authData) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal auth message: %w", err) - } - - //nolint:staticcheck - if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { - return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err) - } - - addr := &address.Address{URL: r.instanceURL} - addrData, err := 
addr.Marshal() - if err != nil { - return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err) - } - - //nolint:staticcheck - responseMsg, err := messages.MarshalHelloResponse(addrData) - if err != nil { - return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err) - } - return rawPeerID, responseMsg, nil -} - -func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) { - rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) - if err != nil { - return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) - } - - peerID := messages.HashIDToString(rawPeerID) - - if err := r.validator.Validate(authPayload); err != nil { - return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err) - } - - responseMsg, err := messages.MarshalAuthResponse(r.instanceURL) - if err != nil { - return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err) - } - - return rawPeerID, responseMsg, nil -} From 49e65109d25a98bfa73245e1252663d37184533a Mon Sep 17 00:00:00 2001 From: ctrl-zzz <78654296+ctrl-zzz@users.noreply.github.com> Date: Sun, 13 Oct 2024 14:52:43 +0200 Subject: [PATCH 43/81] Add session expire functionality based on inactivity (#2326) Implemented inactivity expiration by checking the status of a peer: after a configurable period of time following netbird down, the peer shows login required. --- management/server/account.go | 133 +++++++++ management/server/account_test.go | 315 +++++++++++++++++++++ management/server/activity/codes.go | 14 + management/server/file_store.go | 3 + management/server/http/accounts_handler.go | 3 + management/server/http/api/openapi.yml | 19 ++ management/server/http/api/types.gen.go | 21 +- management/server/http/peers_handler.go | 54 ++-- management/server/peer.go | 97 +++++-- management/server/peer/peer.go | 20 ++ management/server/peer_test.go | 62 ++++ 11 files changed, 682 insertions(+), 59 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index c468b5eccaf..7c84ad1ca1b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -51,6 +51,7 @@ const ( CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days DefaultPeerLoginExpiration = 24 * time.Hour + DefaultPeerInactivityExpiration = 10 * time.Minute emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) @@ -181,6 +182,8 @@ type DefaultAccountManager struct { dnsDomain string peerLoginExpiry Scheduler + peerInactivityExpiry Scheduler + // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account userDeleteFromIDPEnabled bool @@ -198,6 +201,13 @@ type Settings struct { // Applies to all peers that have Peer.LoginExpirationEnabled set to true. PeerLoginExpiration time.Duration + // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration + PeerInactivityExpirationEnabled bool + + // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. + // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. 
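+	// New accounts created via newAccountWithId default this to DefaultPeerInactivityExpiration (10 minutes) with PeerInactivityExpirationEnabled set to false.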
+ PeerInactivityExpiration time.Duration + // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements RegularUsersViewBlocked bool @@ -228,6 +238,9 @@ func (s *Settings) Copy() *Settings { GroupsPropagationEnabled: s.GroupsPropagationEnabled, JWTAllowGroups: s.JWTAllowGroups, RegularUsersViewBlocked: s.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: s.PeerInactivityExpiration, } if s.Extra != nil { settings.Extra = s.Extra.Copy() @@ -609,6 +622,60 @@ func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { return peers } +// GetInactivePeers returns peers that have been expired by inactivity +func (a *Account) GetInactivePeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, inactivePeer := range a.GetPeersWithInactivity() { + inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + return peers +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. +func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithInactivity() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + // GetPeers returns a list of all Account peers func (a *Account) GetPeers() []*nbpeer.Peer { var peers []*nbpeer.Peer @@ -975,6 +1042,7 @@ func BuildManager( dnsDomain: dnsDomain, eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), + peerInactivityExpiry: NewDefaultScheduler(), userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, integratedPeerValidator: integratedPeerValidator, metrics: metrics, @@ -1103,6 +1171,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.checkAndSchedulePeerLoginExpiration(ctx, account) } + err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) + if err != nil { + return nil, err + } + updatedAccount := account.UpdateSettings(newSettings) err = am.Store.SaveAccount(ctx, account) @@ -1113,6 +1186,26 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return updatedAccount, nil } +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, 
newSettings *Settings, userID, accountID string) error { + if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { + event := activity.AccountPeerInactivityExpirationEnabled + if !newSettings.PeerInactivityExpirationEnabled { + event = activity.AccountPeerInactivityExpirationDisabled + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + } else { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + am.StoreEvent(ctx, userID, accountID, accountID, event, nil) + } + + if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + + return nil +} + func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -1148,6 +1241,43 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context } } +// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found +func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { + return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + log.Errorf("failed getting account %s expiring peers", account.Id) + return account.GetNextInactivePeerExpiration() + } + + expiredPeers := account.GetInactivePeers() + var peerIDs []string + for _, peer := range expiredPeers { + peerIDs = append(peerIDs, peer.ID) + } + + log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) + + if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { + log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) + return account.GetNextInactivePeerExpiration() + } + + return account.GetNextInactivePeerExpiration() + } +} + +// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { + am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) + if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { + go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) + } +} + // newAccount creates a new Account with a generated ID and generated default setup keys. 
// If ID is already in use (due to collision) we try one more time before returning error func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { @@ -2412,6 +2542,9 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac PeerLoginExpiration: DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, }, } diff --git a/management/server/account_test.go b/management/server/account_test.go index b20071cba04..19514dad181 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1957,6 +1957,90 @@ func TestAccount_GetExpiredPeers(t *testing.T) { } } +func TestAccount_GetInactivePeers(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expectedPeers map[string]struct{} + } + testCases := []test{ + { + name: "Peers with inactivity expiration disabled, no expired peers", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + InactivityExpirationEnabled: false, + }, + "peer-2": { + InactivityExpirationEnabled: false, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Two peers expired", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + ID: "peer-1", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC().Add(-45 * time.Second), + Connected: false, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-30 * time.Minute), + UserID: userID, + }, + "peer-2": { + ID: "peer-2", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC().Add(-45 * time.Second), + Connected: false, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-2 * time.Hour), + UserID: userID, + }, + "peer-3": { + ID: "peer-3", + InactivityExpirationEnabled: true, + Status: &nbpeer.PeerStatus{ + LastSeen: time.Now().UTC(), + Connected: true, + LoginExpired: false, + }, + LastLogin: time.Now().UTC().Add(-1 * time.Hour), + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + "peer-2": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + } + + expiredPeers := account.GetInactivePeers() + assert.Len(t, expiredPeers, len(testCase.expectedPeers)) + for _, peer := range expiredPeers { + if _, ok := testCase.expectedPeers[peer.ID]; !ok { + t.Fatalf("expected to have peer %s expired", peer.ID) + } + } + }) + } +} + func TestAccount_GetPeersWithExpiration(t *testing.T) { type test struct { name string @@ -2026,6 +2110,75 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) { } } +func TestAccount_GetPeersWithInactivity(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expectedPeers map[string]struct{} + } + + testCases := []test{ + { + name: "No account peers, no peers with expiration", + peers: map[string]*nbpeer.Peer{}, + expectedPeers: map[string]struct{}{}, + }, + { + name: "Peers with login expiration disabled, no peers with expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{}, + }, + { + name: 
"Peers with login expiration enabled, return peers with expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + ID: "peer-1", + InactivityExpirationEnabled: true, + UserID: userID, + }, + "peer-2": { + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expectedPeers: map[string]struct{}{ + "peer-1": {}, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + } + + actual := account.GetPeersWithInactivity() + assert.Len(t, actual, len(testCase.expectedPeers)) + if len(testCase.expectedPeers) > 0 { + for k := range testCase.expectedPeers { + contains := false + for _, peer := range actual { + if k == peer.ID { + contains = true + } + } + assert.True(t, contains) + } + } + }) + } +} + func TestAccount_GetNextPeerExpiration(t *testing.T) { type test struct { name string @@ -2187,6 +2340,168 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } } +func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { + type test struct { + name string + peers map[string]*nbpeer.Peer + expiration time.Duration + expirationEnabled bool + expectedNextRun bool + expectedNextExpiration time.Duration + } + + expectedNextExpiration := time.Minute + testCases := []test{ + { + name: "No peers, no expiration", + peers: map[string]*nbpeer.Peer{}, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "No connected peers, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: false, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: false, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Connected peers with disabled expiration, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + }, + InactivityExpirationEnabled: false, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "Expired peers, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + { + name: "To be expired peer, return expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: false, + LoginExpired: false, + LastSeen: time.Now().Add(-1 * time.Second), + }, + InactivityExpirationEnabled: true, + LastLogin: time.Now().UTC(), + UserID: userID, + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: true, + }, + InactivityExpirationEnabled: true, + UserID: userID, + }, + }, + expiration: time.Minute, + expirationEnabled: false, + expectedNextRun: true, + 
expectedNextExpiration: expectedNextExpiration, + }, + { + name: "Peers added with setup keys, no expiration", + peers: map[string]*nbpeer.Peer{ + "peer-1": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: false, + }, + InactivityExpirationEnabled: true, + SetupKey: "key", + }, + "peer-2": { + Status: &nbpeer.PeerStatus{ + Connected: true, + LoginExpired: false, + }, + InactivityExpirationEnabled: true, + SetupKey: "key", + }, + }, + expiration: time.Second, + expirationEnabled: false, + expectedNextRun: false, + expectedNextExpiration: time.Duration(0), + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + account := &Account{ + Peers: testCase.peers, + Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, + } + + expiration, ok := account.GetNextInactivePeerExpiration() + assert.Equal(t, testCase.expectedNextRun, ok) + if testCase.expectedNextRun { + assert.True(t, expiration >= 0 && expiration <= testCase.expectedNextExpiration) + } else { + assert.Equal(t, expiration, testCase.expectedNextExpiration) + } + }) + } +} + func TestAccount_SetJWTGroups(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 4ee57f1817c..188494241c6 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -139,6 +139,13 @@ const ( PostureCheckUpdated Activity = 61 // PostureCheckDeleted indicates that the user deleted a posture check PostureCheckDeleted Activity = 62 + + PeerInactivityExpirationEnabled Activity = 63 + PeerInactivityExpirationDisabled Activity = 64 + + AccountPeerInactivityExpirationEnabled Activity = 65 + AccountPeerInactivityExpirationDisabled Activity = 66 + AccountPeerInactivityExpirationDurationUpdated Activity = 67 ) var activityMap = map[Activity]Code{ @@ -205,6 +212,13 @@ var activityMap = map[Activity]Code{ PostureCheckCreated: {"Posture check created", "posture.check.created"}, PostureCheckUpdated: {"Posture check updated", "posture.check.updated"}, PostureCheckDeleted: {"Posture check deleted", "posture.check.deleted"}, + + PeerInactivityExpirationEnabled: {"Peer inactivity expiration enabled", "peer.inactivity.expiration.enable"}, + PeerInactivityExpirationDisabled: {"Peer inactivity expiration disabled", "peer.inactivity.expiration.disable"}, + + AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"}, + AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, + AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, } // StringCode returns a string code of the activity diff --git a/management/server/file_store.go b/management/server/file_store.go index df3e9bb7757..561e133cec8 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -95,6 +95,9 @@ func restore(ctx context.Context, file string) (*FileStore, error) { account.Settings = &Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: DefaultPeerLoginExpiration, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, } } diff --git a/management/server/http/accounts_handler.go 
b/management/server/http/accounts_handler.go index 91caa15128a..4d4066de487 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -78,6 +78,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)), } if req.Settings.Extra != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index fd0343e97bb..9d51482481a 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -54,6 +54,14 @@ components: description: Period of time after which peer login expires (seconds). type: integer example: 43200 + peer_inactivity_expiration_enabled: + description: Enables or disables peer inactivity expiration globally. After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). + type: boolean + example: true + peer_inactivity_expiration: + description: Period of time of inactivity after which peer session expires (seconds). + type: integer + example: 43200 regular_users_view_blocked: description: Allows blocking regular users from viewing parts of the system. type: boolean @@ -81,6 +89,8 @@ components: required: - peer_login_expiration_enabled - peer_login_expiration + - peer_inactivity_expiration_enabled + - peer_inactivity_expiration - regular_users_view_blocked AccountExtraSettings: type: object @@ -243,6 +253,9 @@ components: login_expiration_enabled: type: boolean example: false + inactivity_expiration_enabled: + type: boolean + example: false approval_required: description: (Cloud only) Indicates whether peer needs approval type: boolean @@ -251,6 +264,7 @@ components: - name - ssh_enabled - login_expiration_enabled + - inactivity_expiration_enabled Peer: allOf: - $ref: '#/components/schemas/PeerMinimum' @@ -327,6 +341,10 @@ components: type: string format: date-time example: "2023-05-05T09:00:35.477782Z" + inactivity_expiration_enabled: + description: Indicates whether peer inactivity expiration has been enabled or not + type: boolean + example: false approval_required: description: (Cloud only) Indicates whether peer needs approval type: boolean @@ -354,6 +372,7 @@ components: - last_seen - login_expiration_enabled - login_expired + - inactivity_expiration_enabled - os - ssh_enabled - user_id diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 570ec03c5bc..e2870d5d8ef 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -220,6 +220,12 @@ type AccountSettings struct { // JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups. JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"` + // PeerInactivityExpiration Period of time of inactivity after which peer session expires (seconds). + PeerInactivityExpiration int `json:"peer_inactivity_expiration"` + + // PeerInactivityExpirationEnabled Enables or disables peer inactivity expiration globally. 
After peer's session has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). + PeerInactivityExpirationEnabled bool `json:"peer_inactivity_expiration_enabled"` + // PeerLoginExpiration Period of time after which peer login expires (seconds). PeerLoginExpiration int `json:"peer_login_expiration"` @@ -538,6 +544,9 @@ type Peer struct { // Id Peer ID Id string `json:"id"` + // InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + // Ip Peer's IP address Ip string `json:"ip"` @@ -613,6 +622,9 @@ type PeerBatch struct { // Id Peer ID Id string `json:"id"` + // InactivityExpirationEnabled Indicates whether peer inactivity expiration has been enabled or not + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + // Ip Peer's IP address Ip string `json:"ip"` @@ -677,10 +689,11 @@ type PeerNetworkRangeCheckAction string // PeerRequest defines model for PeerRequest. type PeerRequest struct { // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` - LoginExpirationEnabled bool `json:"login_expiration_enabled"` - Name string `json:"name"` - SshEnabled bool `json:"ssh_enabled"` + ApprovalRequired *bool `json:"approval_required,omitempty"` + InactivityExpirationEnabled bool `json:"inactivity_expiration_enabled"` + LoginExpirationEnabled bool `json:"login_expiration_enabled"` + Name string `json:"name"` + SshEnabled bool `json:"ssh_enabled"` } // PersonalAccessToken defines model for PersonalAccessToken. diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 4fbbc3106d3..a5856a0e43c 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -7,6 +7,8 @@ import ( "net/http" "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" @@ -14,7 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" ) // PeersHandler is a handler that returns peers of the account @@ -87,6 +88,8 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, SSHEnabled: req.SshEnabled, Name: req.Name, LoginExpirationEnabled: req.LoginExpirationEnabled, + + InactivityExpirationEnabled: req.InactivityExpirationEnabled, } if req.ApprovalRequired != nil { @@ -331,29 +334,30 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD } return &api.Peer{ - Id: peer.ID, - Name: peer.Name, - Ip: peer.IP.String(), - ConnectionIp: peer.Location.ConnectionIP.String(), - Connected: peer.Status.Connected, - LastSeen: peer.Status.LastSeen, - Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), - KernelVersion: peer.Meta.KernelVersion, - GeonameId: int(peer.Location.GeoNameID), - Version: peer.Meta.WtVersion, - Groups: groupsInfo, - SshEnabled: peer.SSHEnabled, - Hostname: peer.Meta.Hostname, - UserId: peer.UserID, - UiVersion: peer.Meta.UIVersion, - DnsLabel: fqdn(peer, dnsDomain), - LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.LastLogin, - LoginExpired: 
peer.Status.LoginExpired, - ApprovalRequired: !approved, - CountryCode: peer.Location.CountryCode, - CityName: peer.Location.CityName, - SerialNumber: peer.Meta.SystemSerialNumber, + Id: peer.ID, + Name: peer.Name, + Ip: peer.IP.String(), + ConnectionIp: peer.Location.ConnectionIP.String(), + Connected: peer.Status.Connected, + LastSeen: peer.Status.LastSeen, + Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion), + KernelVersion: peer.Meta.KernelVersion, + GeonameId: int(peer.Location.GeoNameID), + Version: peer.Meta.WtVersion, + Groups: groupsInfo, + SshEnabled: peer.SSHEnabled, + Hostname: peer.Meta.Hostname, + UserId: peer.UserID, + UiVersion: peer.Meta.UIVersion, + DnsLabel: fqdn(peer, dnsDomain), + LoginExpirationEnabled: peer.LoginExpirationEnabled, + LastLogin: peer.LastLogin, + LoginExpired: peer.Status.LoginExpired, + ApprovalRequired: !approved, + CountryCode: peer.Location.CountryCode, + CityName: peer.Location.CityName, + SerialNumber: peer.Meta.SystemSerialNumber, + InactivityExpirationEnabled: peer.InactivityExpirationEnabled, } } @@ -387,6 +391,8 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, SerialNumber: peer.Meta.SystemSerialNumber, + + InactivityExpirationEnabled: peer.InactivityExpirationEnabled, } } diff --git a/management/server/peer.go b/management/server/peer.go index a7d4f3b06aa..a4c7e126675 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -110,6 +110,31 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return err } + expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) + if err != nil { + return err + } + + if peer.AddedWithSSOLogin() { + if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { + am.checkAndSchedulePeerLoginExpiration(ctx, account) + } + + if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + } + + if expired { + // we need to update other peers because when peer login expires all other peers are notified to disconnect from + // the expired one. Here we notify them that connection is now allowed again. + am.updateAccountPeers(ctx, account) + } + + return nil +} + +func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -138,25 +163,15 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK account.UpdatePeer(peer) - err = am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) + err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) if err != nil { - return err + return false, err } - if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) - } - - if oldStatus.LoginExpired { - // we need to update other peers because when peer login expires all other peers are notified to disconnect from - // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account) - } - - return nil + return oldStatus.LoginExpired, nil } -// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. 
+// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -219,6 +234,25 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } + if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { + + if !peer.AddedWithSSOLogin() { + return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") + } + + peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled + + event := activity.PeerInactivityExpirationEnabled + if !update.InactivityExpirationEnabled { + event = activity.PeerInactivityExpirationDisabled + } + am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) + + if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + } + account.UpdatePeer(peer) err = am.Store.SaveAccount(ctx, account) @@ -442,23 +476,24 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s registrationTime := time.Now().UTC() newPeer = &nbpeer.Peer{ - ID: xid.New().String(), - AccountID: accountID, - Key: peer.Key, - SetupKey: upperKey, - IP: freeIP, - Meta: peer.Meta, - Name: peer.Meta.Hostname, - DNSLabel: freeLabel, - UserID: userID, - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, - SSHEnabled: false, - SSHKey: peer.SSHKey, - LastLogin: registrationTime, - CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, - Ephemeral: ephemeral, - Location: peer.Location, + ID: xid.New().String(), + AccountID: accountID, + Key: peer.Key, + SetupKey: upperKey, + IP: freeIP, + Meta: peer.Meta, + Name: peer.Meta.Hostname, + DNSLabel: freeLabel, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, + SSHEnabled: false, + SSHKey: peer.SSHKey, + LastLogin: registrationTime, + CreatedAt: registrationTime, + LoginExpirationEnabled: addedByUser, + Ephemeral: ephemeral, + Location: peer.Location, + InactivityExpirationEnabled: addedByUser, } opEvent.TargetID = newPeer.ID opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 3d9ba18e9e5..9a53459a8c8 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -38,6 +38,8 @@ type Peer struct { // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // Works with LastLogin LoginExpirationEnabled bool + + InactivityExpirationEnabled bool // LastLogin the time when peer performed last login operation LastLogin time.Time // CreatedAt records the time the peer was created @@ -187,6 +189,8 @@ func (p *Peer) Copy() *Peer { CreatedAt: p.CreatedAt, Ephemeral: p.Ephemeral, Location: p.Location, + + InactivityExpirationEnabled: p.InactivityExpirationEnabled, } } @@ -219,6 +223,22 @@ func (p *Peer) MarkLoginExpired(expired bool) { p.Status = newStatus } +// SessionExpired indicates whether the peer's session has expired or not. 
+// If Peer.LastLogin plus the expiresIn duration has happened already; then session has expired. +// Return true if a session has expired, false otherwise, and time left to expiration (negative when expired). +// Session expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property. +// Session expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled. +// Only peers added by interactive SSO login can be expired. +func (p *Peer) SessionExpired(expiresIn time.Duration) (bool, time.Duration) { + if !p.AddedWithSSOLogin() || !p.InactivityExpirationEnabled || p.Status.Connected { + return false, 0 + } + expiresAt := p.Status.LastSeen.Add(expiresIn) + now := time.Now() + timeLeft := expiresAt.Sub(now) + return timeLeft <= 0, timeLeft +} + // LoginExpired indicates whether the peer's login has expired or not. // If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired. // Return true if a login has expired, false otherwise, and time left to expiration (negative when expired). diff --git a/management/server/peer_test.go b/management/server/peer_test.go index f3bf0ddba78..c5edb5636ad 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -82,6 +82,68 @@ func TestPeer_LoginExpired(t *testing.T) { } } +func TestPeer_SessionExpired(t *testing.T) { + tt := []struct { + name string + expirationEnabled bool + lastLogin time.Time + connected bool + expected bool + accountSettings *Settings + }{ + { + name: "Peer Inactivity Expiration Disabled. Peer Inactivity Should Not Expire", + expirationEnabled: false, + connected: false, + lastLogin: time.Now().UTC().Add(-1 * time.Second), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Hour, + }, + expected: false, + }, + { + name: "Peer Inactivity Should Expire", + expirationEnabled: true, + connected: false, + lastLogin: time.Now().UTC().Add(-1 * time.Second), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + expected: true, + }, + { + name: "Peer Inactivity Should Not Expire", + expirationEnabled: true, + connected: true, + lastLogin: time.Now().UTC(), + accountSettings: &Settings{ + PeerInactivityExpirationEnabled: true, + PeerInactivityExpiration: time.Second, + }, + expected: false, + }, + } + + for _, c := range tt { + t.Run(c.name, func(t *testing.T) { + peerStatus := &nbpeer.PeerStatus{ + Connected: c.connected, + } + peer := &nbpeer.Peer{ + InactivityExpirationEnabled: c.expirationEnabled, + LastLogin: c.lastLogin, + Status: peerStatus, + UserID: userID, + } + + expired, _ := peer.SessionExpired(c.accountSettings.PeerInactivityExpiration) + assert.Equal(t, expired, c.expected) + }) + } +} + func TestAccountManager_GetNetworkMap(t *testing.T) { manager, err := createManager(t) if err != nil { From cee95461d15fac87ba01ddb63f0715099e985d4f Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 15 Oct 2024 15:03:17 +0200 Subject: [PATCH 44/81] [client] Add universal bin build and update sign workflow version (#2738) * Add universal binaries build for macOS * update sign pipeline version * handle info.plist in sign workflow --- .github/workflows/release.yml | 4 ++-- .goreleaser.yaml | 3 +++ .goreleaser_ui_darwin.yaml | 3 +++ client/ui/Info.plist | 12 ------------ 4 files changed, 8 insertions(+), 14 deletions(-) delete mode 100644 client/ui/Info.plist diff --git a/.github/workflows/release.yml 
b/.github/workflows/release.yml index b2e2437e6bb..1b85ec7efd2 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.14" + SIGN_PIPE_VER: "v0.0.15" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" @@ -223,4 +223,4 @@ jobs: repo: netbirdio/sign-pipelines ref: ${{ env.SIGN_PIPE_VER }} token: ${{ secrets.SIGN_GITHUB_TOKEN }} - inputs: '{ "tag": "${{ github.ref }}" }' + inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' diff --git a/.goreleaser.yaml b/.goreleaser.yaml index cf2ce4f4f0d..e718b3fcd1a 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -96,6 +96,9 @@ builds: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser mod_timestamp: "{{ .CommitTimestamp }}" +universal_binaries: + - id: netbird + archives: - builds: - netbird diff --git a/.goreleaser_ui_darwin.yaml b/.goreleaser_ui_darwin.yaml index bccb7f4717a..0a008207587 100644 --- a/.goreleaser_ui_darwin.yaml +++ b/.goreleaser_ui_darwin.yaml @@ -23,6 +23,9 @@ builds: tags: - load_wgnt_from_rsrc +universal_binaries: + - id: netbird-ui-darwin + archives: - builds: - netbird-ui-darwin diff --git a/client/ui/Info.plist b/client/ui/Info.plist deleted file mode 100644 index 8441110b921..00000000000 --- a/client/ui/Info.plist +++ /dev/null @@ -1,12 +0,0 @@ - - - - - CFBundleExecutable - netbird-ui - CFBundleIconFile - Netbird - LSUIElement - 1 - - From 8c8900be57b76e40bedcb4c6c56d4a57d2afd9bf Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 16 Oct 2024 17:35:59 +0200 Subject: [PATCH 45/81] [client] Exclude loopback from NAT (#2747) --- client/firewall/iptables/router_linux.go | 4 +++- client/firewall/nftables/router_linux.go | 15 +++++++++++++++ client/firewall/nftables/router_linux_test.go | 12 ++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index e60c352d5c1..12932392871 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -433,10 +433,12 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { intdir := "-i" + lointdir := "-o" if inverse { intdir = "-o" + lointdir = "-i" } - return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump} + return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump} } func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 404ba695780..03526fee7b9 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -425,11 +425,15 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { destExp := generateCIDRMatcherExpressions(false, pair.Destination) dir := expr.MetaKeyIIFNAME + notDir := expr.MetaKeyOIFNAME if pair.Inverse { dir = expr.MetaKeyOIFNAME + notDir = expr.MetaKeyIIFNAME } + lo := ifname("lo") intf := ifname(r.wgIface.Name()) + exprs := []expr.Any{ &expr.Meta{ Key: dir, @@ -440,6 +444,17 @@ func (r *router) addNatRule(pair firewall.RouterPair) error { Register: 1, Data: intf, }, + + // 
We need to exclude the loopback interface as this changes the ebpf proxy port + &expr.Meta{ + Key: notDir, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: lo, + }, } exprs = append(exprs, sourceExp...) diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 25b7587ac67..c07111b4e10 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -69,6 +69,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) { Register: 1, Data: ifname(ifaceMock.Name()), }, + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: ifname("lo"), + }, ) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) @@ -97,6 +103,12 @@ func TestNftablesManager_AddNatRule(t *testing.T) { Register: 1, Data: ifname(ifaceMock.Name()), }, + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: ifname("lo"), + }, ) inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) From f942491b91d8d4627402512f5ec8ff5054a570f7 Mon Sep 17 00:00:00 2001 From: Emre Oksum Date: Wed, 16 Oct 2024 18:51:21 +0300 Subject: [PATCH 46/81] Update Zitadel version on quickstart script (#2744) Update Zitadel version at docker compose in quickstart script from 2.54.3 to 2.54.10 because 2.54.3 isn't stable and has a lot of bugs. --- infrastructure_files/getting-started-with-zitadel.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 2c5c35d5302..16b2364fb56 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -873,7 +873,7 @@ services: zitadel: restart: 'always' networks: [netbird] - image: 'ghcr.io/zitadel/zitadel:v2.54.3' + image: 'ghcr.io/zitadel/zitadel:v2.54.10' command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE' env_file: - ./zitadel.env From 96d22076849027e7b8179feabbdd9892d600eb5a Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 16 Oct 2024 18:55:30 +0300 Subject: [PATCH 47/81] Fix JSON function compatibility for SQLite and PostgreSQL (#2746) resolves the issue with json_array_length compatibility between SQLite and PostgreSQL. It adjusts the query to conditionally cast types: PostgreSQL: Casts to json with ::json. SQLite: Uses the text representation directly. --- management/server/sql_store.go | 12 ++++++++++-- management/server/sql_store_test.go | 13 +++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index de3dfa9455e..47395f51109 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1154,8 +1154,16 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { var group nbgroup.Group - result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). - Order("json_array_length(peers) DESC").First(&group, "name = ? 
and account_id = ?", groupName, accountID) + // TODO: This fix is accepted for now, but if we need to handle this more frequently + // we may need to reconsider changing the types. + query := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) + if s.storeEngine == PostgresStoreEngine { + query = query.Order("json_array_length(peers::json) DESC") + } else { + query = query.Order("json_array_length(peers) DESC") + } + + result := query.First(&group, "name = ? and account_id = ?", groupName, accountID) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 20e812ea709..000eb1b11b2 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1251,3 +1251,16 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { }) } + +func TestSqlite_GetGroupByName(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + if err != nil { + t.Fatal(err) + } + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID) + require.NoError(t, err) + require.Equal(t, "All", group.Name) +} From ccd4ae6315853249002de2cb6d31dea6a3d330f4 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 17 Oct 2024 19:21:35 +0200 Subject: [PATCH 48/81] Fix domain information is up to date check (#2754) --- management/server/account.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 7c84ad1ca1b..4c4806bb52b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -45,15 +45,15 @@ import ( ) const ( - PublicCategory = "public" - PrivateCategory = "private" - UnknownCategory = "unknown" - CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days - CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days - DefaultPeerLoginExpiration = 24 * time.Hour + PublicCategory = "public" + PrivateCategory = "private" + UnknownCategory = "unknown" + CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days + CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerInactivityExpiration = 10 * time.Minute - emptyUserID = "empty user ID in claims" - errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -1440,7 +1440,7 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, u return err } cachedAccount := &Account{ - Id: accountID, + Id: accountID, Users: make(map[string]*User), } for _, user := range accountUsers { @@ -2276,7 +2276,7 @@ func handleNotFound(err error) error { } func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { - return claims.Domain != "" && claims.Domain != domain && claims.DomainCategory == PrivateCategory && domainCategory != PrivateCategory + return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx 
context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { From 507a40bd7f7a9119fd9ddbfe72ef84160b864dae Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 17 Oct 2024 20:39:59 +0200 Subject: [PATCH 49/81] Fix decompress zip path (#2755) Since 0.30.2 the decompressed binary path from the signed package has changed now it doesn't contain the arch suffix this change handles that --- client/ui/netbird-ui.rb.tmpl | 6 +++--- release_files/install.sh | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/client/ui/netbird-ui.rb.tmpl b/client/ui/netbird-ui.rb.tmpl index 9efddd750b0..06971909d4c 100644 --- a/client/ui/netbird-ui.rb.tmpl +++ b/client/ui/netbird-ui.rb.tmpl @@ -8,11 +8,11 @@ cask "{{ $projectName }}" do if Hardware::CPU.intel? url "{{ $amdURL }}" sha256 "{{ crypto.SHA256 $amdFileBytes }}" - app "netbird_ui_darwin_amd64", target: "Netbird UI.app" + app "netbird_ui_darwin", target: "Netbird UI.app" else url "{{ $armURL }}" sha256 "{{ crypto.SHA256 $armFileBytes }}" - app "netbird_ui_darwin_arm64", target: "Netbird UI.app" + app "netbird_ui_darwin", target: "Netbird UI.app" end depends_on formula: "netbird" @@ -36,4 +36,4 @@ cask "{{ $projectName }}" do name "Netbird UI" desc "Netbird UI Client" homepage "https://www.netbird.io/" -end \ No newline at end of file +end diff --git a/release_files/install.sh b/release_files/install.sh index b7a6c08f9a7..b0fec27339c 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -86,7 +86,7 @@ download_release_binary() { # Unzip the app and move to INSTALL_DIR unzip -q -o "$BINARY_NAME" - mv "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/" + mv -v "netbird_ui_${OS_TYPE}/" "$INSTALL_DIR/" || mv -v "netbird_ui_${OS_TYPE}_${ARCH}/" "$INSTALL_DIR/" else ${SUDO} mkdir -p "$INSTALL_DIR" tar -xzvf "$BINARY_NAME" From c8d8748dcf051b647f751511a385c180556ee929 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 18 Oct 2024 17:28:58 +0200 Subject: [PATCH 50/81] Update sign workflow version (#2756) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1b85ec7efd2..14e383a27c5 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.15" + SIGN_PIPE_VER: "v0.0.16" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" From 88e4fc2245e5490b40678e51998e08bfd49adad9 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 19 Oct 2024 18:32:17 +0200 Subject: [PATCH 51/81] Release global lock on early error (#2760) --- management/server/account.go | 1 + 1 file changed, 1 insertion(+) diff --git a/management/server/account.go b/management/server/account.go index 4c4806bb52b..cca3b4e52df 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2220,6 +2220,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont // check again if the domain has a primary account because of simultaneous requests domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if handleNotFound(err) != nil { + cancel() log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", nil, err } From 9929b22afcc05367ae4217ef1a8d95dc9f645a51 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 21 Oct 2024 14:39:28 +0200 Subject: 
[PATCH 52/81] Replace suite tests with regular go tests (#2762) * Replace file suite tests with go tests * Replace file suite tests with go tests --- util/file_suite_test.go | 126 -------------------------------------- util/file_test.go | 130 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 126 deletions(-) delete mode 100644 util/file_suite_test.go diff --git a/util/file_suite_test.go b/util/file_suite_test.go deleted file mode 100644 index 3de7db49bdd..00000000000 --- a/util/file_suite_test.go +++ /dev/null @@ -1,126 +0,0 @@ -package util_test - -import ( - "crypto/md5" - "encoding/hex" - "io" - "os" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - - "github.com/netbirdio/netbird/util" -) - -var _ = Describe("Client", func() { - - var ( - tmpDir string - ) - - type TestConfig struct { - SomeMap map[string]string - SomeArray []string - SomeField int - } - - BeforeEach(func() { - var err error - tmpDir, err = os.MkdirTemp("", "wiretrustee_util_test_tmp_*") - Expect(err).NotTo(HaveOccurred()) - }) - - AfterEach(func() { - err := os.RemoveAll(tmpDir) - Expect(err).NotTo(HaveOccurred()) - }) - - Describe("Config", func() { - Context("in JSON format", func() { - It("should be written and read successfully", func() { - - m := make(map[string]string) - m["key1"] = "value1" - m["key2"] = "value2" - - arr := []string{"value1", "value2"} - - written := &TestConfig{ - SomeMap: m, - SomeArray: arr, - SomeField: 99, - } - - err := util.WriteJson(tmpDir+"/testconfig.json", written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - Expect(read.(*TestConfig).SomeMap["key1"]).To(BeEquivalentTo(written.SomeMap["key1"])) - Expect(read.(*TestConfig).SomeMap["key2"]).To(BeEquivalentTo(written.SomeMap["key2"])) - Expect(read.(*TestConfig).SomeArray).To(ContainElements(arr)) - Expect(read.(*TestConfig).SomeField).To(BeEquivalentTo(written.SomeField)) - - }) - }) - }) - - Describe("Copying file contents", func() { - Context("from one file to another", func() { - It("should be successful", func() { - - src := tmpDir + "/copytest_src" - dst := tmpDir + "/copytest_dst" - - err := util.WriteJson(src, []string{"1", "2", "3"}) - Expect(err).NotTo(HaveOccurred()) - - err = util.CopyFileContents(src, dst) - Expect(err).NotTo(HaveOccurred()) - - hashSrc := md5.New() - hashDst := md5.New() - - srcFile, err := os.Open(src) - Expect(err).NotTo(HaveOccurred()) - - dstFile, err := os.Open(dst) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashSrc, srcFile) - Expect(err).NotTo(HaveOccurred()) - - _, err = io.Copy(hashDst, dstFile) - Expect(err).NotTo(HaveOccurred()) - - err = srcFile.Close() - Expect(err).NotTo(HaveOccurred()) - - err = dstFile.Close() - Expect(err).NotTo(HaveOccurred()) - - Expect(hex.EncodeToString(hashSrc.Sum(nil)[:16])).To(BeEquivalentTo(hex.EncodeToString(hashDst.Sum(nil)[:16]))) - }) - }) - }) - - Describe("Handle config file without full path", func() { - Context("config file handling", func() { - It("should be successful", func() { - written := &TestConfig{ - SomeField: 123, - } - cfgFile := "test_cfg.json" - defer os.Remove(cfgFile) - - err := util.WriteJson(cfgFile, written) - Expect(err).NotTo(HaveOccurred()) - - read, err := util.ReadJson(cfgFile, &TestConfig{}) - Expect(err).NotTo(HaveOccurred()) - Expect(read).NotTo(BeNil()) - }) - }) - }) -}) diff --git a/util/file_test.go b/util/file_test.go index 
1330e738e8d..566d8eda6fb 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,12 +1,142 @@ package util import ( + "crypto/md5" + "encoding/hex" + "io" "os" "reflect" "strings" "testing" + + "github.com/stretchr/testify/require" ) +type TestConfig struct { + SomeMap map[string]string + SomeArray []string + SomeField int +} + +func TestConfigJSON(t *testing.T) { + tests := []struct { + name string + config *TestConfig + expectedError bool + }{ + { + name: "Valid JSON config", + config: &TestConfig{ + SomeMap: map[string]string{"key1": "value1", "key2": "value2"}, + SomeArray: []string{"value1", "value2"}, + SomeField: 99, + }, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + + err := WriteJson(tmpDir+"/testconfig.json", tt.config) + require.NoError(t, err) + + read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) + require.NoError(t, err) + require.NotNil(t, read) + require.Equal(t, tt.config.SomeMap["key1"], read.(*TestConfig).SomeMap["key1"]) + require.Equal(t, tt.config.SomeMap["key2"], read.(*TestConfig).SomeMap["key2"]) + require.ElementsMatch(t, tt.config.SomeArray, read.(*TestConfig).SomeArray) + require.Equal(t, tt.config.SomeField, read.(*TestConfig).SomeField) + }) + } +} + +func TestCopyFileContents(t *testing.T) { + tests := []struct { + name string + srcContent []string + expectedError bool + }{ + { + name: "Copy file contents successfully", + srcContent: []string{"1", "2", "3"}, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + + src := tmpDir + "/copytest_src" + dst := tmpDir + "/copytest_dst" + + err := WriteJson(src, tt.srcContent) + require.NoError(t, err) + + err = CopyFileContents(src, dst) + require.NoError(t, err) + + hashSrc := md5.New() + hashDst := md5.New() + + srcFile, err := os.Open(src) + require.NoError(t, err) + defer func() { + _ = srcFile.Close() + }() + + dstFile, err := os.Open(dst) + require.NoError(t, err) + defer func() { + _ = dstFile.Close() + }() + + _, err = io.Copy(hashSrc, srcFile) + require.NoError(t, err) + + _, err = io.Copy(hashDst, dstFile) + require.NoError(t, err) + + require.Equal(t, hex.EncodeToString(hashSrc.Sum(nil)[:16]), hex.EncodeToString(hashDst.Sum(nil)[:16])) + }) + } +} + +func TestHandleConfigFileWithoutFullPath(t *testing.T) { + tests := []struct { + name string + config *TestConfig + expectedError bool + }{ + { + name: "Handle config file without full path", + config: &TestConfig{ + SomeField: 123, + }, + expectedError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgFile := "test_cfg.json" + defer func() { + _ = os.Remove(cfgFile) + }() + + err := WriteJson(cfgFile, tt.config) + require.NoError(t, err) + + read, err := ReadJson(cfgFile, &TestConfig{}) + require.NoError(t, err) + require.NotNil(t, read) + }) + } +} + func TestReadJsonWithEnvSub(t *testing.T) { type Config struct { CertFile string `json:"CertFile"` From 0106a95f7a28e2135d994e695ec758f7bcb40f2d Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 22 Oct 2024 13:29:17 +0300 Subject: [PATCH 53/81] lock account and use transaction (#2767) Signed-off-by: bcmmbaga --- management/server/account.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index cca3b4e52df..b49b82f9128 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ 
-2012,10 +2012,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) - unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) defer func() { - if unlockPeer != nil { - unlockPeer() + if unlockAccount != nil { + unlockAccount() } }() @@ -2024,12 +2024,12 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st var hasChanges bool var user *User err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + user, err = transaction.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { return fmt.Errorf("error getting user: %w", err) } - groups, err := am.Store.GetAccountGroups(ctx, accountID) + groups, err := transaction.GetAccountGroups(ctx, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -2087,8 +2087,8 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error incrementing network serial: %w", err) } } - unlockPeer() - unlockPeer = nil + unlockAccount() + unlockAccount = nil return nil }) From 30ebcf38c7bb351cda7f062e588423e3bf6f64ff Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 22 Oct 2024 20:53:14 +0200 Subject: [PATCH 54/81] [client] Eliminate UDP proxy in user-space mode (#2712) In the case of user space WireGuard mode, use in-memory proxy between the TURN/Relay connection and the WireGuard Bind. We keep the UDP proxy and eBPF proxy for kernel mode. The key change is the new wgproxy/bind and the iface/bind/ice_bind changes. Everything else is just to fulfill the dependencies. 
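For readers skimming the large diff below, here is a minimal sketch of the in-memory forwarding idea. It is illustrative only, not the exact implementation: it reuses the bind types added by this patch (bind.RecvMessage, bind.Endpoint, the ICEBind RecvChan), assumes imports of "context", "net" and the client/iface/bind package, and omits the pause/close handling and the fake 127.1.x.x endpoint addressing that the real wgproxy/bind proxy provides.

    // Sketch: instead of writing relayed packets to a local UDP socket (the old
    // UDP proxy), read from the relayed net.Conn and hand each payload directly
    // to the channel that the WireGuard bind's receive function drains.
    func forwardRelayed(ctx context.Context, remote net.Conn, ep *bind.Endpoint, recv chan<- bind.RecvMessage) {
        buf := make([]byte, 1500)
        for {
            n, err := remote.Read(buf)
            if err != nil {
                return // remote conn closed or proxy stopped
            }
            // Copy the payload so the next Read cannot overwrite it before the
            // bind's receive function has consumed the message.
            msg := bind.RecvMessage{Endpoint: ep, Buffer: append([]byte(nil), buf[:n]...)}
            select {
            case recv <- msg:
            case <-ctx.Done():
                return
            }
        }
    }

The actual code in client/iface/wgproxy/bind/proxy.go and client/iface/bind/ice_bind.go below follows this shape, with additional logic for pausing, idempotent close, and mapping each relayed peer to a fake local endpoint address.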
--- client/iface/bind/bind.go | 142 --------- client/iface/bind/endpoint.go | 5 + client/iface/bind/ice_bind.go | 276 ++++++++++++++++++ client/iface/device/device_android.go | 5 +- client/iface/device/device_darwin.go | 5 +- client/iface/device/device_ios.go | 5 +- client/iface/device/device_netstack.go | 5 +- client/iface/device/device_usp_unix.go | 6 +- client/iface/device/device_windows.go | 5 +- client/iface/iface.go | 51 +++- client/iface/iface_android.go | 43 --- client/iface/iface_create.go | 7 + client/iface/iface_create_android.go | 24 ++ ...iface_darwin.go => iface_create_darwin.go} | 36 +-- client/iface/iface_guid_windows.go | 10 + client/iface/iface_ios.go | 31 -- client/iface/iface_moc.go | 7 + client/iface/iface_new_android.go | 24 ++ client/iface/iface_new_darwin.go | 34 +++ client/iface/iface_new_ios.go | 26 ++ client/iface/iface_new_unix.go | 45 +++ client/iface/iface_new_windows.go | 32 ++ client/iface/iface_test.go | 132 ++++++++- client/iface/iface_unix.go | 49 ---- client/iface/iface_windows.go | 41 --- client/iface/iwginterface.go | 2 + client/iface/iwginterface_windows.go | 2 + client/iface/wgproxy/bind/proxy.go | 137 +++++++++ .../wgproxy/ebpf/portlookup.go | 0 .../wgproxy/ebpf/portlookup_test.go | 0 .../{internal => iface}/wgproxy/ebpf/proxy.go | 2 +- .../wgproxy/ebpf/proxy_test.go | 0 .../wgproxy/ebpf/wrapper.go | 2 +- client/iface/wgproxy/factory_kernel.go | 47 +++ .../iface/wgproxy/factory_kernel_freebsd.go | 26 ++ client/iface/wgproxy/factory_usp.go | 27 ++ client/iface/wgproxy/proxy.go | 15 + client/iface/wgproxy/proxy_linux_test.go | 56 ++++ .../{internal => iface}/wgproxy/proxy_test.go | 8 +- .../usp => iface/wgproxy/udp}/proxy.go | 28 +- client/internal/dns/server_test.go | 34 ++- client/internal/engine.go | 33 +-- client/internal/engine_test.go | 21 +- client/internal/peer/conn.go | 46 +-- client/internal/peer/conn_test.go | 25 +- client/internal/routemanager/manager_test.go | 10 +- .../systemops/systemops_generic_test.go | 29 +- client/internal/wgproxy/factory_linux.go | 50 ---- client/internal/wgproxy/factory_nonlinux.go | 21 -- client/internal/wgproxy/proxy.go | 15 - 50 files changed, 1129 insertions(+), 553 deletions(-) delete mode 100644 client/iface/bind/bind.go create mode 100644 client/iface/bind/endpoint.go create mode 100644 client/iface/bind/ice_bind.go delete mode 100644 client/iface/iface_android.go create mode 100644 client/iface/iface_create_android.go rename client/iface/{iface_darwin.go => iface_create_darwin.go} (50%) create mode 100644 client/iface/iface_guid_windows.go delete mode 100644 client/iface/iface_ios.go create mode 100644 client/iface/iface_new_android.go create mode 100644 client/iface/iface_new_darwin.go create mode 100644 client/iface/iface_new_ios.go create mode 100644 client/iface/iface_new_unix.go create mode 100644 client/iface/iface_new_windows.go delete mode 100644 client/iface/iface_unix.go delete mode 100644 client/iface/iface_windows.go create mode 100644 client/iface/wgproxy/bind/proxy.go rename client/{internal => iface}/wgproxy/ebpf/portlookup.go (100%) rename client/{internal => iface}/wgproxy/ebpf/portlookup_test.go (100%) rename client/{internal => iface}/wgproxy/ebpf/proxy.go (99%) rename client/{internal => iface}/wgproxy/ebpf/proxy_test.go (100%) rename client/{internal => iface}/wgproxy/ebpf/wrapper.go (95%) create mode 100644 client/iface/wgproxy/factory_kernel.go create mode 100644 client/iface/wgproxy/factory_kernel_freebsd.go create mode 100644 client/iface/wgproxy/factory_usp.go create mode 100644 
client/iface/wgproxy/proxy.go create mode 100644 client/iface/wgproxy/proxy_linux_test.go rename client/{internal => iface}/wgproxy/proxy_test.go (90%) rename client/{internal/wgproxy/usp => iface/wgproxy/udp}/proxy.go (83%) delete mode 100644 client/internal/wgproxy/factory_linux.go delete mode 100644 client/internal/wgproxy/factory_nonlinux.go delete mode 100644 client/internal/wgproxy/proxy.go diff --git a/client/iface/bind/bind.go b/client/iface/bind/bind.go deleted file mode 100644 index ba6153cb738..00000000000 --- a/client/iface/bind/bind.go +++ /dev/null @@ -1,142 +0,0 @@ -package bind - -import ( - "fmt" - "net" - "runtime" - "sync" - - "github.com/pion/stun/v2" - "github.com/pion/transport/v3" - log "github.com/sirupsen/logrus" - "golang.org/x/net/ipv4" - wgConn "golang.zx2c4.com/wireguard/conn" -) - -type receiverCreator struct { - iceBind *ICEBind -} - -func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { - return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn) -} - -type ICEBind struct { - *wgConn.StdNetBind - - muUDPMux sync.Mutex - - transportNet transport.Net - udpMux *UniversalUDPMuxDefault - - filterFn FilterFn -} - -func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { - ib := &ICEBind{ - transportNet: transportNet, - filterFn: filterFn, - } - - rc := receiverCreator{ - ib, - } - ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc) - return ib -} - -// GetICEMux returns the ICE UDPMux that was created and used by ICEBind -func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { - s.muUDPMux.Lock() - defer s.muUDPMux.Unlock() - if s.udpMux == nil { - return nil, fmt.Errorf("ICEBind has not been initialized yet") - } - - return s.udpMux, nil -} - -func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { - s.muUDPMux.Lock() - defer s.muUDPMux.Unlock() - - s.udpMux = NewUniversalUDPMuxDefault( - UniversalUDPMuxParams{ - UDPConn: conn, - Net: s.transportNet, - FilterFn: s.filterFn, - }, - ) - return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { - msgs := ipv4MsgsPool.Get().(*[]ipv4.Message) - defer ipv4MsgsPool.Put(msgs) - for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] - } - var numMsgs int - if runtime.GOOS == "linux" { - numMsgs, err = pc.ReadBatch(*msgs, 0) - if err != nil { - return 0, err - } - } else { - msg := &(*msgs)[0] - msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) - if err != nil { - return 0, err - } - numMsgs = 1 - } - for i := 0; i < numMsgs; i++ { - msg := &(*msgs)[i] - - // todo: handle err - ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr) - if ok { - sizes[i] = 0 - } else { - sizes[i] = msg.N - } - - addrPort := msg.Addr.(*net.UDPAddr).AddrPort() - ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation - wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep) - eps[i] = ep - } - return numMsgs, nil - } -} - -func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) { - for i := range buffers { - if !stun.IsMessage(buffers[i]) { - continue - } - - msg, err := s.parseSTUNMessage(buffers[i][:n]) - if err != nil { - buffers[i] = []byte{} - return true, err - } - - muxErr := s.udpMux.HandleSTUNMessage(msg, addr) - if muxErr != nil { - log.Warnf("failed to handle STUN packet") - } - - buffers[i] = []byte{} - return true, nil - } - return false, nil -} - 
-func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) { - msg := &stun.Message{ - Raw: raw, - } - if err := msg.Decode(); err != nil { - return nil, err - } - - return msg, nil -} diff --git a/client/iface/bind/endpoint.go b/client/iface/bind/endpoint.go new file mode 100644 index 00000000000..1926ff88f1d --- /dev/null +++ b/client/iface/bind/endpoint.go @@ -0,0 +1,5 @@ +package bind + +import wgConn "golang.zx2c4.com/wireguard/conn" + +type Endpoint = wgConn.StdNetEndpoint diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go new file mode 100644 index 00000000000..ccdcc2cda30 --- /dev/null +++ b/client/iface/bind/ice_bind.go @@ -0,0 +1,276 @@ +package bind + +import ( + "fmt" + "net" + "net/netip" + "runtime" + "strings" + "sync" + + "github.com/pion/stun/v2" + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" + "golang.org/x/net/ipv4" + wgConn "golang.zx2c4.com/wireguard/conn" +) + +type RecvMessage struct { + Endpoint *Endpoint + Buffer []byte +} + +type receiverCreator struct { + iceBind *ICEBind +} + +func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { + return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn) +} + +// ICEBind is a bind implementation with two main features: +// 1. filter out STUN messages and handle them +// 2. forward the received packets to the WireGuard interface from the relayed connection +// +// ICEBind.endpoints var is a map that stores the connection for each relayed peer. Fake address is just an IP address +// without port, in the format of 127.1.x.x where x.x is the last two octets of the peer address. We try to avoid to +// use the port because in the Send function the wgConn.Endpoint the port info is not exported. +type ICEBind struct { + *wgConn.StdNetBind + RecvChan chan RecvMessage + + transportNet transport.Net + filterFn FilterFn + endpoints map[netip.Addr]net.Conn + endpointsMu sync.Mutex + // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a + // new closed channel. With the closedChanMu we can safely close the channel and create a new one + closedChan chan struct{} + closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. 
+ closed bool + + muUDPMux sync.Mutex + udpMux *UniversalUDPMuxDefault +} + +func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { + b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) + ib := &ICEBind{ + StdNetBind: b, + RecvChan: make(chan RecvMessage, 1), + transportNet: transportNet, + filterFn: filterFn, + endpoints: make(map[netip.Addr]net.Conn), + closedChan: make(chan struct{}), + closed: true, + } + + rc := receiverCreator{ + ib, + } + ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc) + return ib +} + +func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { + s.closed = false + s.closedChanMu.Lock() + s.closedChan = make(chan struct{}) + s.closedChanMu.Unlock() + fns, port, err := s.StdNetBind.Open(uport) + if err != nil { + return nil, 0, err + } + fns = append(fns, s.receiveRelayed) + return fns, port, nil +} + +func (s *ICEBind) Close() error { + if s.closed { + return nil + } + s.closed = true + + close(s.closedChan) + + return s.StdNetBind.Close() +} + +// GetICEMux returns the ICE UDPMux that was created and used by ICEBind +func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { + s.muUDPMux.Lock() + defer s.muUDPMux.Unlock() + if s.udpMux == nil { + return nil, fmt.Errorf("ICEBind has not been initialized yet") + } + + return s.udpMux, nil +} + +func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) { + fakeUDPAddr, err := fakeAddress(peerAddress) + if err != nil { + return nil, err + } + + // force IPv4 + fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4()) + if !ok { + return nil, fmt.Errorf("failed to convert IP to netip.Addr") + } + + b.endpointsMu.Lock() + b.endpoints[fakeAddr] = conn + b.endpointsMu.Unlock() + + return fakeUDPAddr, nil +} + +func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) { + fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4()) + if !ok { + log.Warnf("failed to convert IP to netip.Addr") + return + } + + b.endpointsMu.Lock() + defer b.endpointsMu.Unlock() + delete(b.endpoints, fakeAddr) +} + +func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { + b.endpointsMu.Lock() + conn, ok := b.endpoints[ep.DstIP()] + b.endpointsMu.Unlock() + if !ok { + log.Infof("failed to find endpoint for %s", ep.DstIP()) + return b.StdNetBind.Send(bufs, ep) + } + + for _, buf := range bufs { + if _, err := conn.Write(buf); err != nil { + return err + } + } + return nil +} + +func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc { + s.muUDPMux.Lock() + defer s.muUDPMux.Unlock() + + s.udpMux = NewUniversalUDPMuxDefault( + UniversalUDPMuxParams{ + UDPConn: conn, + Net: s.transportNet, + FilterFn: s.filterFn, + }, + ) + return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { + msgs := ipv4MsgsPool.Get().(*[]ipv4.Message) + defer ipv4MsgsPool.Put(msgs) + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + } + var numMsgs int + if runtime.GOOS == "linux" { + numMsgs, err = pc.ReadBatch(*msgs, 0) + if err != nil { + return 0, err + } + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err + } + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + + // todo: handle err + ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr) + if ok { + sizes[i] = 0 + } else { + sizes[i] = msg.N + } + + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := 
&wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep + } + return numMsgs, nil + } +} + +func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) { + for i := range buffers { + if !stun.IsMessage(buffers[i]) { + continue + } + + msg, err := s.parseSTUNMessage(buffers[i][:n]) + if err != nil { + buffers[i] = []byte{} + return true, err + } + + muxErr := s.udpMux.HandleSTUNMessage(msg, addr) + if muxErr != nil { + log.Warnf("failed to handle STUN packet") + } + + buffers[i] = []byte{} + return true, nil + } + return false, nil +} + +func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) { + msg := &stun.Message{ + Raw: raw, + } + if err := msg.Decode(); err != nil { + return nil, err + } + + return msg, nil +} + +// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the +// WireGuard. Critical part is do not block if the Closed() has been called. +func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { + c.closedChanMu.RLock() + defer c.closedChanMu.RUnlock() + + select { + case <-c.closedChan: + return 0, net.ErrClosed + case msg, ok := <-c.RecvChan: + if !ok { + return 0, net.ErrClosed + } + copy(buffs[0], msg.Buffer) + sizes[0] = len(msg.Buffer) + eps[0] = wgConn.Endpoint(msg.Endpoint) + return 1, nil + } +} + +// fakeAddress returns a fake address that is used to as an identifier for the peer. +// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address. +func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) { + octets := strings.Split(peerAddress.IP.String(), ".") + if len(octets) != 4 { + return nil, fmt.Errorf("invalid IP format") + } + + newAddr := &net.UDPAddr{ + IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])), + Port: peerAddress.Port, + } + return newAddr, nil +} diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 29e3f409df6..fac2ba63df9 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -5,7 +5,6 @@ package device import ( "strings" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/device" @@ -31,13 +30,13 @@ type WGTunDevice struct { configurer WGConfigurer } -func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice { +func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { return &WGTunDevice{ address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, tunAdapter: tunAdapter, } } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 03e85a7f17f..b5a128bc1cc 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -6,7 +6,6 @@ import ( "fmt" "os/exec" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" @@ -29,14 +28,14 @@ type TunDevice struct { configurer WGConfigurer } -func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { +func 
NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { return &TunDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, } } diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 226e8a2e0cb..b9591e0b8c6 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -6,7 +6,6 @@ package device import ( "os" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/device" @@ -30,13 +29,13 @@ type TunDevice struct { configurer WGConfigurer } -func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice { +func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { return &TunDevice{ name: name, address: address, port: port, key: key, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, tunFd: tunFd, } } diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index 440a1ca191e..f5d39e9e074 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -6,7 +6,6 @@ package device import ( "fmt" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" @@ -31,7 +30,7 @@ type TunNetstackDevice struct { configurer WGConfigurer } -func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice { +func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { return &TunNetstackDevice{ name: name, address: address, @@ -39,7 +38,7 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m key: key, mtu: mtu, listenAddress: listenAddress, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, } } diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4175f65569e..643d77565c2 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -7,7 +7,6 @@ import ( "os" "runtime" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" @@ -30,7 +29,7 @@ type USPDevice struct { configurer WGConfigurer } -func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice { +func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { log.Infof("using userspace bind mode") checkUser() @@ -41,7 +40,8 @@ func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn)} + iceBind: iceBind, + } } func (t *USPDevice) Create() (WGConfigurer, error) { diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index f3e216ccd5d..86968d06d7e 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -4,7 +4,6 @@ import ( "fmt" "net/netip" - "github.com/pion/transport/v3" log 
"github.com/sirupsen/logrus" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/device" @@ -32,14 +31,14 @@ type TunDevice struct { configurer WGConfigurer } -func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { +func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { return &TunDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn), + iceBind: iceBind, } } diff --git a/client/iface/iface.go b/client/iface/iface.go index accf5ce0afb..1fb9c269179 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -6,12 +6,16 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" + "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) const ( @@ -22,14 +26,35 @@ const ( type WGAddress = device.WGAddress +type wgProxyFactory interface { + GetProxy() wgproxy.Proxy + Free() error +} + +type WGIFaceOpts struct { + IFaceName string + Address string + WGPort int + WGPrivKey string + MTU int + MobileArgs *device.MobileIFaceArguments + TransportNet transport.Net + FilterFn bind.FilterFn +} + // WGIface represents an interface instance type WGIface struct { tun WGTunDevice userspaceBind bool mu sync.Mutex - configurer device.WGConfigurer - filter device.PacketFilter + configurer device.WGConfigurer + filter device.PacketFilter + wgProxyFactory wgProxyFactory +} + +func (w *WGIface) GetProxy() wgproxy.Proxy { + return w.wgProxyFactory.GetProxy() } // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind @@ -124,22 +149,26 @@ func (w *WGIface) Close() error { w.mu.Lock() defer w.mu.Unlock() - err := w.tun.Close() - if err != nil { - return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err) + var result *multierror.Error + + if err := w.wgProxyFactory.Free(); err != nil { + result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err)) } - err = w.waitUntilRemoved() - if err != nil { + if err := w.tun.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) + } + + if err := w.waitUntilRemoved(); err != nil { log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err) - err = w.Destroy() - if err != nil { - return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err) + if err := w.Destroy(); err != nil { + result = multierror.Append(result, fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)) + return errors.FormatErrorOrNil(result) } log.Infof("interface %s successfully removed", w.Name()) } - return nil + return errors.FormatErrorOrNil(result) } // SetFilter sets packet filters for the userspace implementation diff --git a/client/iface/iface_android.go b/client/iface/iface_android.go deleted file mode 100644 index 5ed476e7060..00000000000 --- a/client/iface/iface_android.go +++ /dev/null @@ -1,43 +0,0 @@ -package iface - -import ( - "fmt" - - "github.com/pion/transport/v3" - - "github.com/netbirdio/netbird/client/iface/bind" - 
"github.com/netbirdio/netbird/client/iface/device" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{ - tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), - userspaceBind: true, - } - return wgIFace, nil -} - -// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. -// Will reuse an existing one. -func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { - w.mu.Lock() - defer w.mu.Unlock() - - cfgr, err := w.tun.Create(routes, dns, searchDomains) - if err != nil { - return err - } - w.configurer = cfgr - return nil -} - -// Create this function make sense on mobile only -func (w *WGIface) Create() error { - return fmt.Errorf("this function has not implemented on this platform") -} diff --git a/client/iface/iface_create.go b/client/iface/iface_create.go index f389019ed73..5e17c6d4149 100644 --- a/client/iface/iface_create.go +++ b/client/iface/iface_create.go @@ -2,6 +2,8 @@ package iface +import "fmt" + // Create creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. // this function is different on Android @@ -17,3 +19,8 @@ func (w *WGIface) Create() error { w.configurer = cfgr return nil } + +// CreateOnAndroid this function make sense on mobile only +func (w *WGIface) CreateOnAndroid([]string, string, []string) error { + return fmt.Errorf("this function has not implemented on non mobile") +} diff --git a/client/iface/iface_create_android.go b/client/iface/iface_create_android.go new file mode 100644 index 00000000000..373a9c95a8b --- /dev/null +++ b/client/iface/iface_create_android.go @@ -0,0 +1,24 @@ +package iface + +import ( + "fmt" +) + +// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. +// Will reuse an existing one. 
+func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { + w.mu.Lock() + defer w.mu.Unlock() + + cfgr, err := w.tun.Create(routes, dns, searchDomains) + if err != nil { + return err + } + w.configurer = cfgr + return nil +} + +// Create this function make sense on mobile only +func (w *WGIface) Create() error { + return fmt.Errorf("this function has not implemented on this platform") +} diff --git a/client/iface/iface_darwin.go b/client/iface/iface_create_darwin.go similarity index 50% rename from client/iface/iface_darwin.go rename to client/iface/iface_create_darwin.go index b46ea0f8067..1d91bce54bd 100644 --- a/client/iface/iface_darwin.go +++ b/client/iface/iface_create_darwin.go @@ -7,39 +7,8 @@ import ( "time" "github.com/cenkalti/backoff/v4" - "github.com/pion/transport/v3" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" ) -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{ - userspaceBind: true, - } - - if netstack.IsEnabled() { - wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) - return wgIFace, nil - } - - wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) - - return wgIFace, nil -} - -// CreateOnAndroid this function make sense on mobile only -func (w *WGIface) CreateOnAndroid([]string, string, []string) error { - return fmt.Errorf("this function has not implemented on this platform") -} - // Create creates a new Wireguard interface, sets a given IP and brings it up. // Will reuse an existing one. // this function is different on Android @@ -65,3 +34,8 @@ func (w *WGIface) Create() error { return backoff.Retry(operation, backOff) } + +// CreateOnAndroid this function make sense on mobile only +func (w *WGIface) CreateOnAndroid([]string, string, []string) error { + return fmt.Errorf("this function has not implemented on this platform") +} diff --git a/client/iface/iface_guid_windows.go b/client/iface/iface_guid_windows.go new file mode 100644 index 00000000000..49492fd3d87 --- /dev/null +++ b/client/iface/iface_guid_windows.go @@ -0,0 +1,10 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/device" +) + +// GetInterfaceGUIDString returns an interface GUID. 
This is useful on Windows only +func (w *WGIface) GetInterfaceGUIDString() (string, error) { + return w.tun.(*device.TunDevice).GetInterfaceGUIDString() +} diff --git a/client/iface/iface_ios.go b/client/iface/iface_ios.go deleted file mode 100644 index fc0214748c1..00000000000 --- a/client/iface/iface_ios.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build ios - -package iface - -import ( - "fmt" - - "github.com/pion/transport/v3" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(address) - if err != nil { - return nil, err - } - wgIFace := &WGIface{ - tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), - userspaceBind: true, - } - return wgIFace, nil -} - -// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. -// Will reuse an existing one. -func (w *WGIface) CreateOnAndroid([]string, string, []string) error { - return fmt.Errorf("this function has not implemented on this platform") -} diff --git a/client/iface/iface_moc.go b/client/iface/iface_moc.go index 703da9ce004..d91a7224ff2 100644 --- a/client/iface/iface_moc.go +++ b/client/iface/iface_moc.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) type MockWGIface struct { @@ -30,6 +31,7 @@ type MockWGIface struct { GetDeviceFunc func() *device.FilteredDevice GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDStringFunc func() (string, error) + GetProxyFunc func() wgproxy.Proxy } func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { @@ -103,3 +105,8 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice { func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { return m.GetStatsFunc(peerKey) } + +func (m *MockWGIface) GetProxy() wgproxy.Proxy { + //TODO implement me + panic("implement me") +} diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go new file mode 100644 index 00000000000..69a8d1fd4b5 --- /dev/null +++ b/client/iface/iface_new_android.go @@ -0,0 +1,24 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + + wgIFace := &WGIface{ + userspaceBind: true, + tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + } + return wgIFace, nil +} diff --git a/client/iface/iface_new_darwin.go b/client/iface/iface_new_darwin.go new file mode 100644 index 00000000000..a92d74e0f90 --- /dev/null +++ b/client/iface/iface_new_darwin.go @@ -0,0 +1,34 @@ +//go:build !ios + +package iface + +import ( + 
"github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + + var tun WGTunDevice + if netstack.IsEnabled() { + tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + } else { + tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + } + + wgIFace := &WGIface{ + userspaceBind: true, + tun: tun, + wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + } + return wgIFace, nil +} diff --git a/client/iface/iface_new_ios.go b/client/iface/iface_new_ios.go new file mode 100644 index 00000000000..363f95e1120 --- /dev/null +++ b/client/iface/iface_new_ios.go @@ -0,0 +1,26 @@ +//go:build ios + +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + + wgIFace := &WGIface{ + tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd), + userspaceBind: true, + wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + } + return wgIFace, nil +} diff --git a/client/iface/iface_new_unix.go b/client/iface/iface_new_unix.go new file mode 100644 index 00000000000..f10b17c9a92 --- /dev/null +++ b/client/iface/iface_new_unix.go @@ -0,0 +1,45 @@ +//go:build (linux && !android) || freebsd + +package iface + +import ( + "fmt" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + wgIFace := &WGIface{} + + if netstack.IsEnabled() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + if device.WireGuardModuleIsLoaded() { + wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet) + wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort) + return wgIFace, nil + } + if device.ModuleTunIsLoaded() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + return nil, fmt.Errorf("couldn't check or 
load tun module") +} diff --git a/client/iface/iface_new_windows.go b/client/iface/iface_new_windows.go new file mode 100644 index 00000000000..2e635549602 --- /dev/null +++ b/client/iface/iface_new_windows.go @@ -0,0 +1,32 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + + var tun WGTunDevice + if netstack.IsEnabled() { + tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + } else { + tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + } + + wgIFace := &WGIface{ + userspaceBind: true, + tun: tun, + wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + } + return wgIFace, nil + +} diff --git a/client/iface/iface_test.go b/client/iface/iface_test.go index 87a68addbfc..85db9cacb8d 100644 --- a/client/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -45,7 +45,16 @@ func TestWGIface_UpdateAddr(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: addr, + WGPort: wgPort, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -118,7 +127,16 @@ func Test_CreateInterface(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: 33100, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -153,7 +171,16 @@ func Test_Close(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: wgPort, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -189,7 +216,16 @@ func TestRecreation(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: wgPort, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -252,7 +288,15 @@ func Test_ConfigureInterface(t *testing.T) { if err != nil { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: wgPort, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -300,7 +344,16 @@ func Test_UpdatePeer(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: 33100, + WGPrivKey: key, + MTU: DefaultMTU, + 
TransportNet: newNet, + } + + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -361,7 +414,16 @@ func Test_RemovePeer(t *testing.T) { t.Fatal(err) } - iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) + opts := WGIFaceOpts{ + IFaceName: ifaceName, + Address: wgIP, + WGPort: 33100, + WGPrivKey: key, + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface, err := NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -418,7 +480,15 @@ func Test_ConnectPeers(t *testing.T) { guid := fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil) + optsPeer1 := WGIFaceOpts{ + IFaceName: peer1ifaceName, + Address: peer1wgIP, + WGPort: peer1wgPort, + WGPrivKey: peer1Key.String(), + MTU: DefaultMTU, + TransportNet: newNet, + } + iface1, err := NewWGIFace(optsPeer1) if err != nil { t.Fatal(err) } @@ -432,7 +502,12 @@ func Test_ConnectPeers(t *testing.T) { t.Fatal(err) } - peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1wgPort)) + localIP, err := getLocalIP() + if err != nil { + t.Fatal(err) + } + + peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer1wgPort)) if err != nil { t.Fatal(err) } @@ -444,7 +519,17 @@ func Test_ConnectPeers(t *testing.T) { if err != nil { t.Fatal(err) } - iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil) + + optsPeer2 := WGIFaceOpts{ + IFaceName: peer2ifaceName, + Address: peer2wgIP, + WGPort: peer2wgPort, + WGPrivKey: peer2Key.String(), + MTU: DefaultMTU, + TransportNet: newNet, + } + + iface2, err := NewWGIFace(optsPeer2) if err != nil { t.Fatal(err) } @@ -458,7 +543,7 @@ func Test_ConnectPeers(t *testing.T) { t.Fatal(err) } - peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2wgPort)) + peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer2wgPort)) if err != nil { t.Fatal(err) } @@ -527,3 +612,28 @@ func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { } return wgtypes.Peer{}, fmt.Errorf("peer not found") } + +func getLocalIP() (string, error) { + // Get all interfaces + addrs, err := net.InterfaceAddrs() + if err != nil { + return "", err + } + + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok { + continue + } + if ipNet.IP.IsLoopback() { + continue + } + + if ipNet.IP.To4() == nil { + continue + } + return ipNet.IP.String(), nil + } + + return "", fmt.Errorf("no local IP found") +} diff --git a/client/iface/iface_unix.go b/client/iface/iface_unix.go deleted file mode 100644 index 09dbb2c1f7d..00000000000 --- a/client/iface/iface_unix.go +++ /dev/null @@ -1,49 +0,0 @@ -//go:build (linux && !android) || freebsd - -package iface - -import ( - "fmt" - "runtime" - - "github.com/pion/transport/v3" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{} - - // move the kernel/usp/netstack 
preference evaluation to upper layer - if netstack.IsEnabled() { - wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) - wgIFace.userspaceBind = true - return wgIFace, nil - } - - if device.WireGuardModuleIsLoaded() { - wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) - wgIFace.userspaceBind = false - return wgIFace, nil - } - - if !device.ModuleTunIsLoaded() { - return nil, fmt.Errorf("couldn't check or load tun module") - } - wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) - wgIFace.userspaceBind = true - return wgIFace, nil -} - -// CreateOnAndroid this function make sense on mobile only -func (w *WGIface) CreateOnAndroid([]string, string, []string) error { - return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS) -} diff --git a/client/iface/iface_windows.go b/client/iface/iface_windows.go deleted file mode 100644 index 6845ef3ddd6..00000000000 --- a/client/iface/iface_windows.go +++ /dev/null @@ -1,41 +0,0 @@ -package iface - -import ( - "fmt" - - "github.com/pion/transport/v3" - - "github.com/netbirdio/netbird/client/iface/bind" - "github.com/netbirdio/netbird/client/iface/device" - "github.com/netbirdio/netbird/client/iface/netstack" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{ - userspaceBind: true, - } - - if netstack.IsEnabled() { - wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) - return wgIFace, nil - } - - wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) - return wgIFace, nil -} - -// CreateOnAndroid this function make sense on mobile only -func (w *WGIface) CreateOnAndroid([]string, string, []string) error { - return fmt.Errorf("this function has not implemented on non mobile") -} - -// GetInterfaceGUIDString returns an interface GUID. 
This is useful on Windows only -func (w *WGIface) GetInterfaceGUIDString() (string, error) { - return w.tun.(*device.TunDevice).GetInterfaceGUIDString() -} diff --git a/client/iface/iwginterface.go b/client/iface/iwginterface.go index cb6d7ccd9ad..f5ab2953905 100644 --- a/client/iface/iwginterface.go +++ b/client/iface/iwginterface.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) type IWGIface interface { @@ -22,6 +23,7 @@ type IWGIface interface { ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error + GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP string) error diff --git a/client/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go index 6baeb66ae0e..96eec52a502 100644 --- a/client/iface/iwginterface_windows.go +++ b/client/iface/iwginterface_windows.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgproxy" ) type IWGIface interface { @@ -20,6 +21,7 @@ type IWGIface interface { ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error + GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP string) error diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go new file mode 100644 index 00000000000..e986d6d7b07 --- /dev/null +++ b/client/iface/wgproxy/bind/proxy.go @@ -0,0 +1,137 @@ +package bind + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/bind" +) + +type ProxyBind struct { + Bind *bind.ICEBind + + wgAddr *net.UDPAddr + wgEndpoint *bind.Endpoint + remoteConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool + + pausedMu sync.Mutex + paused bool + isStarted bool +} + +// AddTurnConn adds a new connection to the bind. +// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the +// WireGuard configuration. 
+func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { + addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn) + if err != nil { + return err + } + + p.wgAddr = addr + p.wgEndpoint = addrToEndpoint(addr) + p.remoteConn = remoteConn + p.ctx, p.cancel = context.WithCancel(ctx) + return err + +} +func (p *ProxyBind) EndpointAddr() *net.UDPAddr { + return p.wgAddr +} + +func (p *ProxyBind) Work() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = false + p.pausedMu.Unlock() + + // Start the proxy only once + if !p.isStarted { + p.isStarted = true + go p.proxyToLocal(p.ctx) + } +} + +func (p *ProxyBind) Pause() { + if p.remoteConn == nil { + return + } + + p.pausedMu.Lock() + p.paused = true + p.pausedMu.Unlock() +} + +func (p *ProxyBind) CloseConn() error { + if p.cancel == nil { + return fmt.Errorf("proxy not started") + } + return p.close() +} + +func (p *ProxyBind) close() error { + p.closeMu.Lock() + defer p.closeMu.Unlock() + + if p.closed { + return nil + } + p.closed = true + + p.cancel() + + p.Bind.RemoveEndpoint(p.wgAddr) + + return p.remoteConn.Close() +} + +func (p *ProxyBind) proxyToLocal(ctx context.Context) { + defer func() { + if err := p.close(); err != nil { + log.Warnf("failed to close remote conn: %s", err) + } + }() + + buf := make([]byte, 1500) + for { + n, err := p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) + return + } + + p.pausedMu.Lock() + if p.paused { + p.pausedMu.Unlock() + continue + } + + msg := bind.RecvMessage{ + Endpoint: p.wgEndpoint, + Buffer: buf[:n], + } + p.Bind.RecvChan <- msg + p.pausedMu.Unlock() + } +} + +func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { + ip, _ := netip.AddrFromSlice(addr.IP.To4()) + addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) + return &bind.Endpoint{AddrPort: addrPort} +} diff --git a/client/internal/wgproxy/ebpf/portlookup.go b/client/iface/wgproxy/ebpf/portlookup.go similarity index 100% rename from client/internal/wgproxy/ebpf/portlookup.go rename to client/iface/wgproxy/ebpf/portlookup.go diff --git a/client/internal/wgproxy/ebpf/portlookup_test.go b/client/iface/wgproxy/ebpf/portlookup_test.go similarity index 100% rename from client/internal/wgproxy/ebpf/portlookup_test.go rename to client/iface/wgproxy/ebpf/portlookup_test.go diff --git a/client/internal/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go similarity index 99% rename from client/internal/wgproxy/ebpf/proxy.go rename to client/iface/wgproxy/ebpf/proxy.go index e850f4533ce..e21fc35d4e2 100644 --- a/client/internal/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -119,7 +119,7 @@ func (p *WGEBPFProxy) Free() error { p.ctxCancel() var result *multierror.Error - if p.conn != nil { // p.conn will be nil if we have failed to listen + if p.conn != nil { if err := p.conn.Close(); err != nil { result = multierror.Append(result, err) } diff --git a/client/internal/wgproxy/ebpf/proxy_test.go b/client/iface/wgproxy/ebpf/proxy_test.go similarity index 100% rename from client/internal/wgproxy/ebpf/proxy_test.go rename to client/iface/wgproxy/ebpf/proxy_test.go diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go similarity index 95% rename from client/internal/wgproxy/ebpf/wrapper.go rename to client/iface/wgproxy/ebpf/wrapper.go index b6a8ac45228..efd5fd946cf 100644 --- 
a/client/internal/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -28,7 +28,7 @@ type ProxyWrapper struct { isStarted bool } -func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go new file mode 100644 index 00000000000..32e96e34f2d --- /dev/null +++ b/client/iface/wgproxy/factory_kernel.go @@ -0,0 +1,47 @@ +//go:build linux && !android + +package wgproxy + +import ( + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" +) + +type KernelFactory struct { + wgPort int + + ebpfProxy *ebpf.WGEBPFProxy +} + +func NewKernelFactory(wgPort int) *KernelFactory { + f := &KernelFactory{ + wgPort: wgPort, + } + + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) + if err := ebpfProxy.Listen(); err != nil { + log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) + return f + } + f.ebpfProxy = ebpfProxy + return f +} + +func (w *KernelFactory) GetProxy() Proxy { + if w.ebpfProxy == nil { + return udpProxy.NewWGUDPProxy(w.wgPort) + } + + return &ebpf.ProxyWrapper{ + WgeBPFProxy: w.ebpfProxy, + } +} + +func (w *KernelFactory) Free() error { + if w.ebpfProxy == nil { + return nil + } + return w.ebpfProxy.Free() +} diff --git a/client/iface/wgproxy/factory_kernel_freebsd.go b/client/iface/wgproxy/factory_kernel_freebsd.go new file mode 100644 index 00000000000..7ac2f99a882 --- /dev/null +++ b/client/iface/wgproxy/factory_kernel_freebsd.go @@ -0,0 +1,26 @@ +package wgproxy + +import ( + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" +) + +// KernelFactory todo: check eBPF support on FreeBSD +type KernelFactory struct { + wgPort int +} + +func NewKernelFactory(wgPort int) *KernelFactory { + f := &KernelFactory{ + wgPort: wgPort, + } + + return f +} + +func (w *KernelFactory) GetProxy() Proxy { + return udpProxy.NewWGUDPProxy(w.wgPort) +} + +func (w *KernelFactory) Free() error { + return nil +} diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go new file mode 100644 index 00000000000..99f5ada017a --- /dev/null +++ b/client/iface/wgproxy/factory_usp.go @@ -0,0 +1,27 @@ +package wgproxy + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" +) + +type USPFactory struct { + bind *bind.ICEBind +} + +func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { + f := &USPFactory{ + bind: iceBind, + } + return f +} + +func (w *USPFactory) GetProxy() Proxy { + return &proxyBind.ProxyBind{ + Bind: w.bind, + } +} + +func (w *USPFactory) Free() error { + return nil +} diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go new file mode 100644 index 00000000000..243aa2bd2a8 --- /dev/null +++ b/client/iface/wgproxy/proxy.go @@ -0,0 +1,15 @@ +package wgproxy + +import ( + "context" + "net" +) + +// Proxy is a transfer layer between the relayed connection and the WireGuard +type Proxy interface { + AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error + EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint + Work() // Work 
start or resume the proxy + Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. + CloseConn() error +} diff --git a/client/iface/wgproxy/proxy_linux_test.go b/client/iface/wgproxy/proxy_linux_test.go new file mode 100644 index 00000000000..298c98cc04f --- /dev/null +++ b/client/iface/wgproxy/proxy_linux_test.go @@ -0,0 +1,56 @@ +//go:build linux && !android + +package wgproxy + +import ( + "context" + "os" + "testing" + + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" +) + +func TestProxyCloseByRemoteConnEBPF(t *testing.T) { + if os.Getenv("GITHUB_ACTIONS") != "true" { + t.Skip("Skipping test as it requires root privileges") + } + ctx := context.Background() + + ebpfProxy := ebpf.NewWGEBPFProxy(51831) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %s", err) + } + + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %s", err) + } + }() + + tests := []struct { + name string + proxy Proxy + }{ + { + name: "ebpf proxy", + proxy: &ebpf.ProxyWrapper{ + WgeBPFProxy: ebpfProxy, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + relayedConn := newMockConn() + err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) + if err != nil { + t.Errorf("error: %v", err) + } + + _ = relayedConn.Close() + if err := tt.proxy.CloseConn(); err != nil { + t.Errorf("error: %v", err) + } + }) + } +} diff --git a/client/internal/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go similarity index 90% rename from client/internal/wgproxy/proxy_test.go rename to client/iface/wgproxy/proxy_test.go index b88ff3f83c1..64b61762112 100644 --- a/client/internal/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" - "github.com/netbirdio/netbird/client/internal/wgproxy/usp" + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" "github.com/netbirdio/netbird/util" ) @@ -84,7 +84,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { }{ { name: "userspace proxy", - proxy: usp.NewWGUserSpaceProxy(51830), + proxy: udpProxy.NewWGUDPProxy(51830), }, } @@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, relayedConn) + err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) if err != nil { t.Errorf("error: %v", err) } diff --git a/client/internal/wgproxy/usp/proxy.go b/client/iface/wgproxy/udp/proxy.go similarity index 83% rename from client/internal/wgproxy/usp/proxy.go rename to client/iface/wgproxy/udp/proxy.go index f73500717a9..8bee099014e 100644 --- a/client/internal/wgproxy/usp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -1,4 +1,4 @@ -package usp +package udp import ( "context" @@ -12,8 +12,8 @@ import ( "github.com/netbirdio/netbird/client/errors" ) -// WGUserSpaceProxy proxies -type WGUserSpaceProxy struct { +// WGUDPProxy proxies +type WGUDPProxy struct { localWGListenPort int remoteConn net.Conn @@ -28,10 +28,10 @@ type WGUserSpaceProxy struct { isStarted bool } -// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation -func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { +// NewWGUDPProxy instantiate a UDP based WireGuard proxy. 
This is not a thread safe implementation +func NewWGUDPProxy(wgPort int) *WGUDPProxy { log.Debugf("Initializing new user space proxy with port %d", wgPort) - p := &WGUserSpaceProxy{ + p := &WGUDPProxy{ localWGListenPort: wgPort, } return p @@ -42,7 +42,7 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { // the connection is complete, an error is returned. Once successfully // connected, any expiration of the context will not affect the // connection. -func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { +func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { dialer := net.Dialer{} localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { @@ -57,7 +57,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) return err } -func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { +func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr { if p.localConn == nil { return nil } @@ -66,7 +66,7 @@ func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { } // Work starts the proxy or resumes it if it was paused -func (p *WGUserSpaceProxy) Work() { +func (p *WGUDPProxy) Work() { if p.remoteConn == nil { return } @@ -83,7 +83,7 @@ func (p *WGUserSpaceProxy) Work() { } // Pause pauses the proxy from receiving data from the remote peer -func (p *WGUserSpaceProxy) Pause() { +func (p *WGUDPProxy) Pause() { if p.remoteConn == nil { return } @@ -94,14 +94,14 @@ func (p *WGUserSpaceProxy) Pause() { } // CloseConn close the localConn -func (p *WGUserSpaceProxy) CloseConn() error { +func (p *WGUDPProxy) CloseConn() error { if p.cancel == nil { return fmt.Errorf("proxy not started") } return p.close() } -func (p *WGUserSpaceProxy) close() error { +func (p *WGUDPProxy) close() error { p.closeMu.Lock() defer p.closeMu.Unlock() @@ -125,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error { } // proxyToRemote proxies from Wireguard to the RemoteKey -func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { +func (p *WGUDPProxy) proxyToRemote(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to remote loop: %s", err) @@ -157,7 +157,7 @@ func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { // proxyToLocal proxies from the Remote peer to local WireGuard // if the proxy is paused it will drain the remote conn and drop the packets -func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { +func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { log.Warnf("error in proxy to local loop: %s", err) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 53d18a67814..4a5aff3eaed 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -267,7 +267,17 @@ func TestUpdateDNSServer(t *testing.T) { if err != nil { t.Fatal(err) } - wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) + + opts := iface.WGIFaceOpts{ + IFaceName: fmt.Sprintf("utun230%d", n), + Address: fmt.Sprintf("100.66.100.%d/32", n+1), + WGPort: 33100, + WGPrivKey: privKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + + wgIface, err := iface.NewWGIFace(opts) if err != nil { t.Fatal(err) } @@ -345,7 +355,15 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { } privKey, _ := 
wgtypes.GeneratePrivateKey() - wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: "utun2301", + Address: "100.66.100.1/32", + WGPort: 33100, + WGPrivKey: privKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + wgIface, err := iface.NewWGIFace(opts) if err != nil { t.Errorf("build interface wireguard: %v", err) return @@ -803,7 +821,17 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { } privKey, _ := wgtypes.GeneratePrivateKey() - wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) + + opts := iface.WGIFaceOpts{ + IFaceName: "utun2301", + Address: "100.66.100.2/24", + WGPort: 33100, + WGPrivKey: privKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + + wgIface, err := iface.NewWGIFace(opts) if err != nil { t.Fatalf("build interface wireguard: %v", err) return nil, err diff --git a/client/internal/engine.go b/client/internal/engine.go index eac8ec098f6..459518de136 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -35,7 +35,6 @@ import ( "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" - "github.com/netbirdio/netbird/client/internal/wgproxy" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -141,8 +140,7 @@ type Engine struct { ctx context.Context cancel context.CancelFunc - wgInterface iface.IWGIface - wgProxyFactory *wgproxy.Factory + wgInterface iface.IWGIface udpMux *bind.UniversalUDPMuxDefault @@ -299,9 +297,6 @@ func (e *Engine) Start() error { } e.wgInterface = wgIface - userspace := e.wgInterface.IsUserspaceBind() - e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort) - if e.config.RosenpassEnabled { log.Infof("rosenpass is enabled") if e.config.RosenpassPermissive { @@ -966,7 +961,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e }, } - peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) + peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) if err != nil { return nil, err } @@ -1117,12 +1112,6 @@ func (e *Engine) parseNATExternalIPMappings() []string { } func (e *Engine) close() { - if e.wgProxyFactory != nil { - if err := e.wgProxyFactory.Free(); err != nil { - log.Errorf("failed closing ebpf proxy: %s", err) - } - } - log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { @@ -1167,21 +1156,29 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { log.Errorf("failed to create pion's stdnet: %s", err) } - var mArgs *device.MobileIFaceArguments + opts := iface.WGIFaceOpts{ + IFaceName: e.config.WgIfaceName, + Address: e.config.WgAddr, + WGPort: e.config.WgPort, + WGPrivKey: e.config.WgPrivateKey.String(), + MTU: iface.DefaultMTU, + TransportNet: transportNet, + FilterFn: e.addrViaRoutes, + } + switch runtime.GOOS { case "android": - mArgs = &device.MobileIFaceArguments{ + opts.MobileArgs = &device.MobileIFaceArguments{ TunAdapter: e.mobileDep.TunAdapter, TunFd: int(e.mobileDep.FileDescriptor), } case "ios": - mArgs = 
&device.MobileIFaceArguments{ + opts.MobileArgs = &device.MobileIFaceArguments{ TunFd: int(e.mobileDep.FileDescriptor), } - default: } - return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes) + return iface.NewWGIFace(opts) } func (e *Engine) wgInterfaceCreate() (err error) { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 74b10ee44fa..d0ba1fffcf1 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -602,7 +602,16 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil) + + opts := iface.WGIFaceOpts{ + IFaceName: wgIfaceName, + Address: wgAddr, + WGPort: engine.config.WgPort, + WGPrivKey: key.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + engine.wgInterface, err = iface.NewWGIFace(opts) assert.NoError(t, err, "shouldn't return error") input := struct { inputSerial uint64 @@ -774,7 +783,15 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { if err != nil { t.Fatal(err) } - engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: wgIfaceName, + Address: wgAddr, + WGPort: 33100, + WGPrivKey: key.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + engine.wgInterface, err = iface.NewWGIFace(opts) assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 1b740388d95..99acfde314e 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -17,8 +17,8 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/client/internal/wgproxy" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -81,11 +81,10 @@ type Conn struct { ctxCancel context.CancelFunc config ConnConfig statusRecorder *Status - wgProxyFactory *wgproxy.Factory signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover relayManager *relayClient.Manager - allowedIPsIP string + allowedIP net.IP + allowedNet string handshaker *Handshaker onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) @@ -116,8 +115,8 @@ type Conn struct { // NewConn creates a new not opened Conn to the remote peer. 
// To establish a connection run Conn.Open -func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) { - _, allowedIPsIP, err := net.ParseCIDR(config.WgConfig.AllowedIps) +func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) { + allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) return nil, err @@ -127,19 +126,17 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu connLog := log.WithField("peer", config.Key) var conn = &Conn{ - log: connLog, - ctx: ctx, - ctxCancel: ctxCancel, - config: config, - statusRecorder: statusRecorder, - wgProxyFactory: wgProxyFactory, - signaler: signaler, - iFaceDiscover: iFaceDiscover, - relayManager: relayManager, - allowedIPsIP: allowedIPsIP.String(), - statusRelay: NewAtomicConnStatus(), - statusICE: NewAtomicConnStatus(), - + log: connLog, + ctx: ctx, + ctxCancel: ctxCancel, + config: config, + statusRecorder: statusRecorder, + signaler: signaler, + relayManager: relayManager, + allowedIP: allowedIP, + allowedNet: allowedNet.String(), + statusRelay: NewAtomicConnStatus(), + statusICE: NewAtomicConnStatus(), iCEDisconnected: make(chan bool, 1), relayDisconnected: make(chan bool, 1), } @@ -692,7 +689,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } if conn.onConnected != nil { - conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIPsIP, remoteRosenpassAddr) + conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr) } } @@ -783,8 +780,13 @@ func (conn *Conn) freeUpConnID() { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { conn.log.Debugf("setup proxied WireGuard connection") - wgProxy := conn.wgProxyFactory.GetProxy() - if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil { + udpAddr := &net.UDPAddr{ + IP: conn.allowedIP, + Port: conn.config.WgConfig.WgListenPort, + } + + wgProxy := conn.config.WgConfig.WgInterface.GetProxy() + if err := wgProxy.AddTurnConn(conn.ctx, udpAddr, remoteConn); err != nil { conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) return nil, err } diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index b4926a9d2ef..e68861c5f04 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -11,7 +11,6 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/util" ) @@ -44,11 +43,7 @@ func TestNewConn_interfaceFilter(t *testing.T) { } func TestConn_GetKey(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil) if err != nil { return } @@ -59,11 +54,7 @@ func TestConn_GetKey(t *testing.T) { } func TestConn_OnRemoteOffer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = 
wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) if err != nil { return } @@ -96,11 +87,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { } func TestConn_OnRemoteAnswer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) if err != nil { return } @@ -132,11 +119,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { wg.Wait() } func TestConn_Status(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) if err != nil { return } diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 2f26f7a5ec9..044a996c777 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -407,7 +407,15 @@ func TestManagerUpdateRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: fmt.Sprintf("utun43%d", n), + Address: "100.65.65.2/24", + WGPort: 33100, + WGPrivKey: peerPrivateKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + wgInterface, err := iface.NewWGIFace(opts) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 238225807f8..ce5b6b8431b 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -61,7 +61,14 @@ func TestAddRemoveRoutes(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: fmt.Sprintf("utun53%d", n), + Address: "100.65.75.2/24", + WGPrivKey: peerPrivateKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + wgInterface, err := iface.NewWGIFace(opts) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() @@ -213,7 +220,15 @@ func TestAddExistAndRemoveRoute(t *testing.T) { if err != nil { t.Fatal(err) } - wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: fmt.Sprintf("utun53%d", n), + Address: "100.65.75.2/24", + WGPort: 33100, + WGPrivKey: peerPrivateKey.String(), + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + wgInterface, err := iface.NewWGIFace(opts) require.NoError(t, err, "should create testing WGIface interface") defer wgInterface.Close() @@ -345,7 +360,15 
@@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen newNet, err := stdnet.NewNet() require.NoError(t, err) - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil) + opts := iface.WGIFaceOpts{ + IFaceName: interfaceName, + Address: ipAddressCIDR, + WGPrivKey: peerPrivateKey.String(), + WGPort: listenPort, + MTU: iface.DefaultMTU, + TransportNet: newNet, + } + wgInterface, err := iface.NewWGIFace(opts) require.NoError(t, err, "should create testing WireGuard interface") err = wgInterface.Create() diff --git a/client/internal/wgproxy/factory_linux.go b/client/internal/wgproxy/factory_linux.go deleted file mode 100644 index 369ba99db1f..00000000000 --- a/client/internal/wgproxy/factory_linux.go +++ /dev/null @@ -1,50 +0,0 @@ -//go:build !android - -package wgproxy - -import ( - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" - "github.com/netbirdio/netbird/client/internal/wgproxy/usp" -) - -type Factory struct { - wgPort int - ebpfProxy *ebpf.WGEBPFProxy -} - -func NewFactory(userspace bool, wgPort int) *Factory { - f := &Factory{wgPort: wgPort} - - if userspace { - return f - } - - ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) - err := ebpfProxy.Listen() - if err != nil { - log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) - return f - } - - f.ebpfProxy = ebpfProxy - return f -} - -func (w *Factory) GetProxy() Proxy { - if w.ebpfProxy != nil { - p := &ebpf.ProxyWrapper{ - WgeBPFProxy: w.ebpfProxy, - } - return p - } - return usp.NewWGUserSpaceProxy(w.wgPort) -} - -func (w *Factory) Free() error { - if w.ebpfProxy == nil { - return nil - } - return w.ebpfProxy.Free() -} diff --git a/client/internal/wgproxy/factory_nonlinux.go b/client/internal/wgproxy/factory_nonlinux.go deleted file mode 100644 index f930b09b3a0..00000000000 --- a/client/internal/wgproxy/factory_nonlinux.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build !linux || android - -package wgproxy - -import "github.com/netbirdio/netbird/client/internal/wgproxy/usp" - -type Factory struct { - wgPort int -} - -func NewFactory(_ bool, wgPort int) *Factory { - return &Factory{wgPort: wgPort} -} - -func (w *Factory) GetProxy() Proxy { - return usp.NewWGUserSpaceProxy(w.wgPort) -} - -func (w *Factory) Free() error { - return nil -} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go deleted file mode 100644 index 558121cdd5a..00000000000 --- a/client/internal/wgproxy/proxy.go +++ /dev/null @@ -1,15 +0,0 @@ -package wgproxy - -import ( - "context" - "net" -) - -// Proxy is a transfer layer between the relayed connection and the WireGuard -type Proxy interface { - AddTurnConn(ctx context.Context, turnConn net.Conn) error - EndpointAddr() *net.UDPAddr - Work() - Pause() - CloseConn() error -} From 7bda385e1b2e3f6ecd570ba6c7ddedb4b62bccd8 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 23 Oct 2024 13:05:02 +0300 Subject: [PATCH 55/81] [management] Optimize network map updates (#2718) * Skip peer update on unchanged network map (#2236) * Enhance network updates by skipping unchanged messages Optimizes the network update process by skipping updates where no changes in the peer update message received. 
* Add unit tests * add locks * Improve concurrency and update peer message handling * Refactor account manager network update tests * fix test * Fix inverted network map update condition * Add default group and policy to test data * Run peer updates in a separate goroutine * Refactor * Refactor lock * Fix peers update by including NetworkMap and posture Checks * go mod tidy * fix merge Signed-off-by: bcmmbaga * fix merge Signed-off-by: bcmmbaga * [management] Skip account peers update if no changes affect peers (#2310) * Remove incrementing network serial and updating peers after group deletion * Update account peer if posture check is linked to policy * Remove account peers update on saving setup key * Refactor group link checking into re-usable functions * Add HasPeers function to group * Refactor group management * Optimize group change effects on account peers * Update account peers if ns group has peers * Refactor group changes * Optimize account peers update in DNS settings * Optimize update of account peers on jwt groups sync * Refactor peer account updates for efficiency * Optimize peer update on user deletion and changes * Remove condition check for network serial update * Optimize account peers updates on route changes * Remove UpdatePeerSSHKey method * Remove unused isPolicyRuleGroupsEmpty * Add tests for peer update behavior on posture check changes * Add tests for peer update behavior on policy changes * Add tests for peer update behavior on group changes * Add tests for peer update behavior on dns settings changes * Refactor * Add tests for peer update behavior on name server changes * Add tests for peer update behavior on user changes * Add tests for peer update behavior on route changes * fix tests * Add tests for peer update behavior on setup key changes * Add tests for peer update behavior on peers changes * fix merge * Fix tests * go mod tidy * Add NameServer and Route comparators * Update network map diff logic with custom comparators * Add tests * Refactor duplicate diff handling logic * fix linter * fix tests * Refactor policy group handling and update logic. 
Signed-off-by: bcmmbaga * Update route check by checking if group has peers Signed-off-by: bcmmbaga * Refactor posture check policy linking logic Signed-off-by: bcmmbaga * Simplify peer update condition in DNS management Refactor the condition for updating account peers to remove redundant checks Signed-off-by: bcmmbaga * fix tests Signed-off-by: bcmmbaga * fix merge Signed-off-by: bcmmbaga * add policy tests Signed-off-by: bcmmbaga * add posture checks tests Signed-off-by: bcmmbaga * fix user and setup key tests Signed-off-by: bcmmbaga * fix account and route tests Signed-off-by: bcmmbaga * fix typo Signed-off-by: bcmmbaga * fix nameserver tests Signed-off-by: bcmmbaga * fix routes tests Signed-off-by: bcmmbaga * fix group tests Signed-off-by: bcmmbaga * upgrade diff package Signed-off-by: bcmmbaga * fix nameserver tests Signed-off-by: bcmmbaga * use generic differ for netip.Addr and netip.Prefix Signed-off-by: bcmmbaga * go mod tidy Signed-off-by: bcmmbaga * add peer tests Signed-off-by: bcmmbaga * fix merge Signed-off-by: bcmmbaga * fix management suite tests Signed-off-by: bcmmbaga * fix postgres tests Signed-off-by: bcmmbaga * enable diff nil structs comparison Signed-off-by: bcmmbaga * skip the update only last sent the serial is larger Signed-off-by: bcmmbaga * refactor peer and user Signed-off-by: bcmmbaga * skip spell check for groupD Signed-off-by: bcmmbaga * Refactor group, ns group, policy and posture checks Signed-off-by: bcmmbaga * skip spell check for GroupD Signed-off-by: bcmmbaga * update account policy check before verifying policy status Signed-off-by: bcmmbaga * Update management/server/route_test.go Co-authored-by: Maycon Santos * Update management/server/route_test.go Co-authored-by: Maycon Santos * Update management/server/route_test.go Co-authored-by: Maycon Santos * Update management/server/route_test.go Co-authored-by: Maycon Santos * Update management/server/route_test.go Co-authored-by: Maycon Santos * add tests missing tests for dns setting groups Signed-off-by: bcmmbaga * add tests for posture checks changes Signed-off-by: bcmmbaga * add ns group and policy tests Signed-off-by: bcmmbaga * add route and group tests Signed-off-by: bcmmbaga * increase Linux test timeout to 10 minutes Signed-off-by: bcmmbaga * Run diff for client posture checks only Signed-off-by: bcmmbaga * add panic recovery and detailed logging in peer update comparison Signed-off-by: bcmmbaga * Fix tests Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga Co-authored-by: Maycon Santos --------- Signed-off-by: bcmmbaga Co-authored-by: Maycon Santos --- .github/workflows/golang-test-linux.yml | 2 +- .github/workflows/golangci-lint.yml | 2 +- go.mod | 3 + go.sum | 6 + management/server/account.go | 7 +- management/server/account_test.go | 384 +++++++++----- management/server/differs/netip.go | 82 +++ management/server/dns.go | 9 +- management/server/dns_test.go | 144 ++++++ management/server/group.go | 50 +- management/server/group/group.go | 5 + management/server/group_test.go | 312 ++++++++++++ management/server/mock_server/account_mock.go | 9 - management/server/nameserver.go | 34 +- management/server/nameserver_test.go | 178 +++++++ management/server/network.go | 4 +- management/server/peer.go | 81 ++- management/server/peer/peer.go | 21 +- management/server/peer_test.go | 319 ++++++++++++ management/server/policy.go | 34 +- management/server/policy_test.go | 374 ++++++++++++++ management/server/posture_checks.go | 35 +- management/server/posture_checks_test.go | 458 +++++++++++++++++ 
management/server/route.go | 19 +- management/server/route_test.go | 279 ++++++++++ management/server/setupkey.go | 2 - management/server/setupkey_test.go | 71 +++ management/server/testdata/store.sql | 5 +- management/server/updatechannel.go | 109 +++- management/server/updatechannel_test.go | 476 ++++++++++++++++++ management/server/user.go | 60 ++- management/server/user_test.go | 165 ++++++ 32 files changed, 3472 insertions(+), 267 deletions(-) create mode 100644 management/server/differs/netip.go diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index e1e1ff2362e..9457d3a6621 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -49,7 +49,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... test_client_on_docker: runs-on: ubuntu-20.04 diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 2d743f79071..dacb1922be9 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif + ignore_words_list: erro,clienta,hastable,iif,groupd skip: go.mod,go.sum only_warn: 1 golangci: diff --git a/go.mod b/go.mod index e7e3c17a68a..a6b83794dab 100644 --- a/go.mod +++ b/go.mod @@ -71,6 +71,7 @@ require ( github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.19.1 + github.com/r3labs/diff/v3 v3.0.1 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -210,6 +211,8 @@ require ( github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect + github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/yuin/goldmark v1.7.1 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect diff --git a/go.sum b/go.sum index e9bc318d6fd..412542d5eb9 100644 --- a/go.sum +++ b/go.sum @@ -605,6 +605,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= +github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg= +github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -697,6 +699,10 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod 
h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= +github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/management/server/account.go b/management/server/account.go index b49b82f9128..a8a244bdf1f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -102,7 +102,6 @@ type AccountManager interface { DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) - UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) @@ -2132,8 +2131,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error getting account: %w", err) } - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } } return nil diff --git a/management/server/account_test.go b/management/server/account_test.go index 19514dad181..3c3fcebc67f 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1122,77 +1122,108 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } -func TestAccountManager_NetworkUpdates(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - return - } - - userID := "account_creator" +func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - account, err := createAccount(manager, "test_account", userID, "") - if err != nil { - t.Fatal(err) + group := group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{}, } - - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) - if err != nil { - t.Fatal("error creating setup key") + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", 
err) return } - if account.Network.Serial != 0 { - t.Errorf("expecting account network to have an initial Serial=0") - return + policy := Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, } + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + require.NoError(t, err) - getPeer := func() *nbpeer.Peer { - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Fatal(err) - return nil - } - expectedPeerKey := key.PublicKey().String() + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ - Key: expectedPeerKey, - Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) - if err != nil { - t.Fatalf("expecting peer1 to be added, got failure %v", err) - return nil + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) } + }() - return peer + group.Peers = []string{peer1.ID, peer2.ID, peer3.ID} + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) + return } - peer1 := getPeer() - peer2 := getPeer() - peer3 := getPeer() + wg.Wait() +} - account, err = manager.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) - return - } +func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { + manager, account, peer1, _, _ := setupNetworkMapTest(t) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { + manager, account, peer1, peer2, _ := setupNetworkMapTest(t) + group := group.Group{ - ID: "group-id", + ID: "groupA", Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + Peers: []string{peer1.ID, peer2.ID}, + } + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) + return } + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + policy := Policy{ Enabled: true, Rules: []*PolicyRule{ { Enabled: true, - Sources: []string{"group-id"}, - Destinations: []string{"group-id"}, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, Bidirectional: true, Action: PolicyTrafficActionAccept, }, @@ -1200,107 +1231,138 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { } wg := sync.WaitGroup{} - t.Run("save group update", func(t *testing.T) { - 
wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 2 { - t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) - } - }() + wg.Add(1) + go func() { + defer wg.Done() - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { - t.Errorf("save group: %v", err) - return + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) } + }() - wg.Wait() - }) + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + t.Errorf("delete default rule: %v", err) + return + } - t.Run("delete policy update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + wg.Wait() +} - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 0 { - t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) - } - }() +func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { + manager, account, peer1, _, peer3 := setupNetworkMapTest(t) - if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { - t.Errorf("delete default rule: %v", err) - return - } + group := group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer3.ID}, + } + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) + return + } - wg.Wait() - }) + policy := Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } - t.Run("save policy update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + t.Errorf("save policy: %v", err) + return + } - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 2 { - t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) - } - }() + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { - t.Errorf("delete default rule: %v", err) - return + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 1 { + t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) } + }() - wg.Wait() - }) - t.Run("delete peer update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 1 { - t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) - } - }() + if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { + t.Errorf("delete peer: %v", err) + return + } - if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { - t.Errorf("delete peer: %v", err) - return - } + 
wg.Wait() +} - wg.Wait() - }) +func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - t.Run("delete group update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 0 { - t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) - } - }() + group := group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + } - // clean policy is pre requirement for delete group - _ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID) + policy := Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + t.Errorf("save policy: %v", err) + return + } - if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { - t.Errorf("delete group: %v", err) - return + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) } + }() - wg.Wait() - }) + // clean policy is pre requirement for delete group + if err := manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { + t.Errorf("delete group: %v", err) + return + } + + wg.Wait() } func TestAccountManager_DeletePeer(t *testing.T) { @@ -2754,3 +2816,73 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { return true } } + +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { + t.Helper() + + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + account, err := createAccount(manager, "test_account", userID, "") + if err != nil { + t.Fatal(err) + } + + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + if err != nil { + t.Fatal("error creating setup key") + } + + getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + expectedPeerKey := key.PublicKey().String() + + peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + Status: &nbpeer.PeerStatus{ + Connected: true, + LastSeen: time.Now().UTC(), + }, + }) + if err != nil { + t.Fatalf("expecting peer to be added, got failure %v", err) + } + + return peer + } + + peer1 := getPeer(manager, setupKey) + peer2 := getPeer(manager, setupKey) 
+ peer3 := getPeer(manager, setupKey) + + return manager, account, peer1, peer2, peer3 +} + +func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(500 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { + t.Helper() + + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + case <-time.After(500 * time.Millisecond): + t.Error("Timed out waiting for update message") + } +} diff --git a/management/server/differs/netip.go b/management/server/differs/netip.go new file mode 100644 index 00000000000..de4aa334c17 --- /dev/null +++ b/management/server/differs/netip.go @@ -0,0 +1,82 @@ +package differs + +import ( + "fmt" + "net/netip" + "reflect" + + "github.com/r3labs/diff/v3" +) + +// NetIPAddr is a custom differ for netip.Addr +type NetIPAddr struct { + DiffFunc func(path []string, a, b reflect.Value, p interface{}) error +} + +func (differ NetIPAddr) Match(a, b reflect.Value) bool { + return diff.AreType(a, b, reflect.TypeOf(netip.Addr{})) +} + +func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { + if a.Kind() == reflect.Invalid { + cl.Add(diff.CREATE, path, nil, b.Interface()) + return nil + } + + if b.Kind() == reflect.Invalid { + cl.Add(diff.DELETE, path, a.Interface(), nil) + return nil + } + + fromAddr, ok1 := a.Interface().(netip.Addr) + toAddr, ok2 := b.Interface().(netip.Addr) + if !ok1 || !ok2 { + return fmt.Errorf("invalid type for netip.Addr") + } + + if fromAddr.String() != toAddr.String() { + cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String()) + } + + return nil +} + +func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { + differ.DiffFunc = dfunc //nolint +} + +// NetIPPrefix is a custom differ for netip.Prefix +type NetIPPrefix struct { + DiffFunc func(path []string, a, b reflect.Value, p interface{}) error +} + +func (differ NetIPPrefix) Match(a, b reflect.Value) bool { + return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{})) +} + +func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { + if a.Kind() == reflect.Invalid { + cl.Add(diff.CREATE, path, nil, b.Interface()) + return nil + } + if b.Kind() == reflect.Invalid { + cl.Add(diff.DELETE, path, a.Interface(), nil) + return nil + } + + fromPrefix, ok1 := a.Interface().(netip.Prefix) + toPrefix, ok2 := b.Interface().(netip.Prefix) + if !ok1 || !ok2 { + return fmt.Errorf("invalid type for netip.Addr") + } + + if fromPrefix.String() != toPrefix.String() { + cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String()) + } + + return nil +} + +func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { + differ.DiffFunc = dfunc //nolint +} diff --git a/management/server/dns.go b/management/server/dns.go index 7410aaa15cc..256b8b12512 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -125,26 +125,29 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID oldSettings := account.DNSSettings.Copy() account.DNSSettings = dnsSettingsToSave.Copy() + addedGroups := 
difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + account.Network.IncSerial() if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) for _, id := range addedGroups { group := account.GetGroup(id) meta := map[string]any{"group": group.Name, "group_id": group.ID} am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) } - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) for _, id := range removedGroups { group := account.GetGroup(id) meta := map[string]any{"group": group.Name, "group_id": group.ID} am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) } - am.updateAccountPeers(ctx, account) + if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { + am.updateAccountPeers(ctx, account) + } return nil } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index c7f435b688d..c675fc12c84 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -6,9 +6,11 @@ import ( "net/netip" "reflect" "testing" + "time" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -476,3 +478,145 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { t.Errorf("Cache should contain name server group 'group2'") } } + +func TestDNSAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates + t.Run("saving dns setting with unused groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ + IP: netip.MustParseAddr(peer1.IP.String()), + NSType: dns.UDPNameServerType, + Port: dns.DefaultDNSPort, + }}, + []string{"groupA"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + // Saving DNS settings with groups that have peers should update account peers and send peer update + t.Run("saving dns 
setting with used groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA", "groupB"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving unchanged DNS settings with used groups should update account peers and not send peer update + // since there is no change in the network map + t.Run("saving unchanged dns setting with used groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA", "groupB"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates + t.Run("removing group with no peers from dns settings", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupA"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing group with peers from DNS settings should trigger updates to account peers and send peer updates + t.Run("removing group with peers from dns settings", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/group.go b/management/server/group.go index aa387c058ea..bdb569e377f 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -121,12 +121,19 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user eventsToStore = append(eventsToStore, events...) 
} + newGroupIDs := make([]string, 0, len(newGroups)) + for _, newGroup := range newGroups { + newGroupIDs = append(newGroupIDs, newGroup.ID) + } + account.Network.IncSerial() if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, newGroupIDs) { + am.updateAccountPeers(ctx, account) + } for _, storeEvent := range eventsToStore { storeEvent() @@ -238,8 +245,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) - am.updateAccountPeers(ctx, account) - return nil } @@ -282,8 +287,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, us am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) } - am.updateAccountPeers(ctx, account) - return allErrors } @@ -336,7 +339,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, []string{group.ID}) { + am.updateAccountPeers(ctx, account) + } return nil } @@ -366,7 +371,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } } - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, []string{group.ID}) { + am.updateAccountPeers(ctx, account) + } return nil } @@ -469,3 +476,32 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { } return false, nil } + +// anyGroupHasPeers checks if any of the given groups in the account have peers. +func anyGroupHasPeers(account *Account, groupIDs []string) bool { + for _, groupID := range groupIDs { + if group, exists := account.Groups[groupID]; exists && group.HasPeers() { + return true + } + } + return false +} + +func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { + for _, groupID := range groupIDs { + if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) { + return true + } + if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked { + return true + } + if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked { + return true + } + if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked { + return true + } + } + + return false +} diff --git a/management/server/group/group.go b/management/server/group/group.go index 79dfd995ce0..d293e1afc6f 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -44,3 +44,8 @@ func (g *Group) Copy() *Group { copy(group.Peers, g.Peers) return group } + +// HasPeers checks if the group has any peers. 
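+// Callers use it to decide whether a group change requires pushing an updated network map to account peers.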
+func (g *Group) HasPeers() bool { + return len(g.Peers) > 0 +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 89b68ad6c07..1e59b74ef5b 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -4,13 +4,16 @@ import ( "context" "errors" "fmt" + "net/netip" "testing" + "time" nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -384,3 +387,312 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A } return am, acc, nil } + +func TestGroupAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{peer1.ID, peer3.ID}, + }, + { + ID: "groupD", + Name: "GroupD", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Saving a group that is not linked to any resource should not update account peers + t.Run("saving unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupB", + Name: "GroupB", + Peers: []string{peer1.ID, peer2.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Adding a peer to a group that is not linked to any resource should not update account peers + // and not send peer update + t.Run("adding peer to unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupAddPeer(context.Background(), account.Id, "groupB", peer3.ID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing a peer from a group that is not linked to any resource should not update account peers + // and not send peer update + t.Run("removing peer from unliked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupDeletePeer(context.Background(), account.Id, "groupB", peer3.ID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting group should not update account peers and not send peer update + t.Run("deleting group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupB") + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for 
peerShouldNotReceiveUpdate") + } + }) + + // adding a group to policy + err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }, false) + assert.NoError(t, err) + + // Saving a group linked to policy should update account peers and send peer update + t.Run("saving linked group to policy", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving an unchanged group should trigger account peers update and not send peer update + // since there is no change in the network map + t.Run("saving unchanged group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // adding peer to a used group should update account peers and send peer update + t.Run("adding peer to linked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupAddPeer(context.Background(), account.Id, "groupA", peer3.ID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // removing peer from a linked group should update account peers and send peer update + t.Run("removing peer from linked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.GroupDeletePeer(context.Background(), account.Id, "groupA", peer3.ID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving a group linked to name server group should update account peers and send peer update + t.Run("saving group linked to name server group", func(t *testing.T) { + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupC"}, + true, nil, true, userID, false, + ) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupC", + Name: "GroupC", + Peers: []string{peer1.ID, peer3.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving a group linked to route should update account peers and 
send peer update + t.Run("saving group linked to route", func(t *testing.T) { + newRoute := route.Route{ + ID: "route", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupA"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupC"}, + } + _, err := manager.CreateRoute( + context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving a group linked to dns settings should update account peers and send peer update + t.Run("saving group linked to dns settings", func(t *testing.T) { + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"groupD"}, + }) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupD", + Name: "GroupD", + Peers: []string{peer1.ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 74557e2275c..681bf533ae4 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -57,7 +57,6 @@ type MockAccountManager struct { GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error - UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) @@ -434,14 +433,6 @@ func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ( return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented") } -// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager -func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { - if am.UpdatePeerSSHKeyFunc != nil { - return 
am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey) - } - return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented") -} - // UpdatePeer mocks UpdatePeerFunc function of the account manager func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) { if am.UpdatePeerFunc != nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 0eb5d9ae4a4..5ebd263dcc2 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -66,13 +66,13 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco account.NameServerGroups[newNSGroup.ID] = newNSGroup account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } - am.updateAccountPeers(ctx, account) - + if anyGroupHasPeers(account, newNSGroup.Groups) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) return newNSGroup.Copy(), nil @@ -80,7 +80,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco // SaveNameServerGroup saves nameserver group func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -98,16 +97,17 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } + oldNSGroup := account.NameServerGroups[nsGroupToSave.ID] account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(ctx, account) - + if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) return nil @@ -131,13 +131,13 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco delete(account.NameServerGroups, nsGroupID) account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(ctx, account) - + if anyGroupHasPeers(account, nsGroup.Groups) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) return nil @@ -277,3 +277,11 @@ func validateDomain(domain string) error { return nil } + +// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. 
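+// It returns false when the nameserver group is disabled both before and after the change; otherwise it reports
+// whether the old or the new distribution groups contain any peers.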
+func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { + if !newNSGroup.Enabled && !oldNSGroup.Enabled { + return false + } + return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) +} diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 8a3fe6eb049..96637cd39a0 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -4,7 +4,9 @@ import ( "context" "net/netip" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" @@ -935,3 +937,179 @@ func TestValidateDomain(t *testing.T) { } } + +func TestNameServerAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + var newNameServerGroupA *nbdns.NameServerGroup + var newNameServerGroupB *nbdns.NameServerGroup + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Creating a nameserver group with a distribution group no peers should not update account peers + // and not send peer update + t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupA, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroupA", "nsGroupA", []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupA"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // saving a nameserver group with a distribution group with no peers should not update account peers + // and not send peer update + t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupA) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Creating a nameserver group with a distribution group no peers should update account peers and send peer update + t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupB, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroupB", "nsGroupB", []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupB"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + 
t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // saving a nameserver group with a distribution group with peers should update account peers and send peer update + t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupB.NameServers = []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.2"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + } + err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // saving unchanged nameserver group should update account peers and not send peer update + t.Run("saving unchanged nameserver group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + newNameServerGroupB.NameServers = []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.2"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }, + } + err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting a nameserver group should update account peers and send peer update + t.Run("deleting nameserver group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeleteNameServerGroup(context.Background(), account.Id, newNameServerGroupB.ID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/network.go b/management/server/network.go index a5b188b4610..8fb6a8b3c12 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -41,9 +41,9 @@ type Network struct { Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. 
- Serial uint64 + Serial uint64 `diff:"-"` - mu sync.Mutex `json:"-" gorm:"-"` + mu sync.Mutex `json:"-" gorm:"-" diff:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 diff --git a/management/server/peer.go b/management/server/peer.go index a4c7e126675..80d43497a70 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "slices" "strings" "sync" "time" @@ -200,7 +201,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) } - if peer.Name != update.Name { + peerLabelUpdated := peer.Name != update.Name + + if peerLabelUpdated { peer.Name = update.Name existingLabels := account.getPeerDNSLabels() @@ -260,7 +263,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return nil, err } - am.updateAccountPeers(ctx, account) + if peerLabelUpdated { + am.updateAccountPeers(ctx, account) + } return peer, nil } @@ -304,6 +309,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou FirewallRulesIsEmpty: true, }, }, + NetworkMap: &NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) @@ -322,6 +328,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } + updateAccountPeers := isPeerInActiveGroup(account, peerID) + err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { return err @@ -332,7 +340,9 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } return nil } @@ -422,9 +432,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var newPeer *nbpeer.Peer + var groupsToAdd []string err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - var groupsToAdd []string var setupKeyID string var setupKeyName string var ephemeral bool @@ -576,7 +586,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, fmt.Errorf("error getting account: %w", err) } - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, groupsToAdd) { + am.updateAccountPeers(ctx, account) + } approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { @@ -897,51 +909,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings return false } -// UpdatePeerSSHKey updates peer's public SSH key -func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { - if sshKey == "" { - log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID) - return nil - } - - account, err := am.Store.GetAccountByPeerID(ctx, peerID) - if err != nil { - return err - } - - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(ctx, account.Id) - if err != nil { - return err - } - - peer := account.GetPeer(peerID) - if peer == nil { - return status.Errorf(status.NotFound, "peer with ID %s not found", peerID) - } - - if peer.SSHKey == sshKey 
{ - log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID) - return nil - } - - peer.SSHKey = sshKey - account.UpdatePeer(peer) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return err - } - - // trigger network map update - am.updateAccountPeers(ctx, account) - - return nil -} - // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -1034,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account postureChecks := am.getPeerPostureChecks(account, p) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) + am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) }(peer) } @@ -1048,3 +1015,15 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { } return labelMap } + +// IsPeerInActiveGroup checks if the given peer is part of a group that is used +// in an active DNS, route, or ACL configuration. +func isPeerInActiveGroup(account *Account, peerID string) bool { + peerGroupIDs := make([]string, 0) + for _, group := range account.Groups { + if slices.Contains(group.Peers, peerID) { + peerGroupIDs = append(peerGroupIDs, group.ID) + } + } + return areGroupChangesAffectPeers(account, peerGroupIDs) +} diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 9a53459a8c8..ef96bce7dd8 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -17,37 +17,37 @@ type Peer struct { // WireGuard public key Key string `gorm:"index"` // A setup key this peer was registered with - SetupKey string + SetupKey string `diff:"-"` // IP address of the Peer IP net.IP `gorm:"serializer:json"` // Meta is a Peer system meta data - Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` + Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"` // Name is peer's name (machine name) Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. peer-dns-label.netbird.cloud DNSLabel string // Status peer's management connection status - Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` + Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"` // The user ID that registered the peer - UserID string + UserID string `diff:"-"` // SSHKey is a public SSH key of the peer SSHKey string // SSHEnabled indicates whether SSH server is enabled on the peer SSHEnabled bool // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. 
// Works with LastLogin - LoginExpirationEnabled bool + LoginExpirationEnabled bool `diff:"-"` - InactivityExpirationEnabled bool + InactivityExpirationEnabled bool `diff:"-"` // LastLogin the time when peer performed last login operation - LastLogin time.Time + LastLogin time.Time `diff:"-"` // CreatedAt records the time the peer was created - CreatedAt time.Time + CreatedAt time.Time `diff:"-"` // Indicate ephemeral peer attribute - Ephemeral bool + Ephemeral bool `diff:"-"` // Geo location based on connection IP - Location Location `gorm:"embedded;embeddedPrefix:location_"` + Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"` } type PeerStatus struct { //nolint:revive @@ -189,7 +189,6 @@ func (p *Peer) Copy() *Peer { CreatedAt: p.CreatedAt, Ephemeral: p.Ephemeral, Location: p.Location, - InactivityExpirationEnabled: p.InactivityExpirationEnabled, } } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index c5edb5636ad..7b2180bf019 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1253,3 +1253,322 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC()) assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) } + +func TestPeerAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) + require.NoError(t, err) + + err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + }) + require.NoError(t, err) + + // create a user with auto groups + _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{ + { + Id: "regularUser1", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + AutoGroups: []string{"groupA"}, + }, + { + Id: "regularUser2", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + AutoGroups: []string{"groupB"}, + }, + { + Id: "regularUser3", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + AutoGroups: []string{"groupC"}, + }, + }, true) + require.NoError(t, err) + + var peer4 *nbpeer.Peer + var peer5 *nbpeer.Peer + var peer6 *nbpeer.Peer + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update + t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err := manager.UpdatePeer(context.Background(), account.Id, userID, peer2) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Adding peer to unlinked group should not update account peers and not send peer update + t.Run("adding peer to unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := 
wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting peer with unlinked group should not update account peers and not send peer update + t.Run("deleting peer with unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating peer label should update account peers and send peer update + t.Run("updating peer label", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + peer1.Name = "peer-1" + _, err = manager.UpdatePeer(context.Background(), account.Id, userID, peer1) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to group linked with policy should update account peers and send peer update + t.Run("adding peer to group linked with policy", func(t *testing.T) { + err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }, false) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + Key: expectedPeerKey, + LoginExpirationEnabled: true, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting peer with linked group to policy should update account peers and send peer update + t.Run("deleting peer with linked group to policy", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to group linked with route should update account peers and send peer update + t.Run("adding peer to group linked with route", func(t *testing.T) { + route := nbroute.Route{ + ID: "testingRoute1", + Network: netip.MustParsePrefix("100.65.250.202/32"), + NetID: "superNet", + NetworkType: nbroute.IPv4Network, + PeerGroups: []string{"groupB"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupB"}, + } + + _, err := manager.CreateRoute( + 
context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, + route.Groups, []string{}, true, userID, route.KeepRoute, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + Key: expectedPeerKey, + LoginExpirationEnabled: true, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting peer with linked group to route should update account peers and send peer update + t.Run("deleting peer with linked group to route", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer5.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to group linked with name server group should update account peers and send peer update + t.Run("adding peer to group linked with name server group", func(t *testing.T) { + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + []string{"groupC"}, + true, []string{}, true, userID, false, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{ + Key: expectedPeerKey, + LoginExpirationEnabled: true, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting peer with linked group to name server group should update account peers and send peer update + t.Run("deleting peer with linked group to route", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer6.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/policy.go b/management/server/policy.go index 75647de449b..05554243032 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -203,6 +203,18 @@ func (p *Policy) UpgradeAndFix() { } } +// ruleGroups returns a list of all groups referenced in the policy's rules, +// including sources and destinations. +func (p *Policy) ruleGroups() []string { + groups := make([]string, 0) + for _, rule := range p.Rules { + groups = append(groups, rule.Sources...) + groups = append(groups, rule.Destinations...) 
+ } + + return groups +} + // FirewallRule is a rule of the firewall. type FirewallRule struct { // PeerIP of the peer @@ -348,7 +360,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - if err = am.savePolicy(account, policy, isUpdate); err != nil { + updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) + if err != nil { return err } @@ -363,7 +376,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } return nil } @@ -428,7 +443,7 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) // savePolicy saves or updates a policy in the given account. // If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. -func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error { +func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) { for index, rule := range policyToSave.Rules { rule.Sources = filterValidGroupIDs(account, rule.Sources) rule.Destinations = filterValidGroupIDs(account, rule.Destinations) @@ -442,18 +457,25 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli if isUpdate { policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) if policyIdx < 0 { - return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) + return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) } + oldPolicy := account.Policies[policyIdx] // Update the existing policy account.Policies[policyIdx] = policyToSave - return nil + + if !policyToSave.Enabled && !oldPolicy.Enabled { + return false, nil + } + updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups()) + + return updateAccountPeers, nil } // Add the new policy to the account account.Policies = append(account.Policies, policyToSave) - return nil + return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil } func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { diff --git a/management/server/policy_test.go b/management/server/policy_test.go index bf9a53d16dd..5b1411702b2 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -5,7 +5,9 @@ import ( "fmt" "net" "testing" + "time" + "github.com/rs/xid" "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" @@ -824,3 +826,375 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int { return 0 // a is equal to b } } + +func TestPolicyAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer3.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + { + ID: "groupD", + Name: "GroupD", + Peers: []string{peer1.ID, peer2.ID}, + }, + }) + assert.NoError(t, err) + + updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + 
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) + }) + + // Saving policy with rule groups with no peers should not update account's peers and not send peer update + t.Run("saving policy with rule groups with no peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-rule-groups-no-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupB"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Saving policy with source group containing peers, but destination group without peers should + // update account's peers and send peer update + t.Run("saving policy where source has peers but destination does not", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-has-peers-destination-none", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupB"}, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving policy with destination group containing peers, but source group without peers should + // update account's peers and send peer update + t.Run("saving policy where destination has peers but source does not", func(t *testing.T) { + policy := Policy{ + ID: "policy-destination-has-peers-source-none", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: false, + Sources: []string{"groupC"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Protocol: PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg2) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving policy with destination and source groups containing peers should update account's peers + // and send peer update + t.Run("saving policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + 
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Disabling policy with destination and source groups containing peers should update account's peers + // and send peer update + t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: false, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating disabled policy with destination and source groups containing peers should not update account's peers + // or send peer update + t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Description: "updated description", + Enabled: false, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Enabling policy with destination and source groups containing peers should update account's peers + // and send peer update + t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving unchanged policy should trigger account peers update but not send peer update + t.Run("saving unchanged policy", func(t *testing.T) { + policy := Policy{ + ID: "policy-source-destination-peers", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupD"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + 
case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting policy should trigger account peers update and send peer update + t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { + policyID := "policy-source-destination-peers" + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + + }) + + // Deleting policy with destination group containing peers, but source group without peers should + // update account's peers and send peer update + t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) { + policyID := "policy-destination-has-peers-source-none" + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg2) + close(done) + }() + + err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting policy with no peers in groups should not update account's peers and not send peer update + t.Run("deleting policy with no peers in groups", func(t *testing.T) { + policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2 + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg1) + close(done) + }() + + err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + +} diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 9a4b679cef5..2dccd8f590c 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -67,7 +67,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI } am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) - if exists { + + if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { am.updateAccountPeers(ctx, account) } @@ -148,13 +149,9 @@ func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureCh return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID) } - // check policy links - for _, policy := range account.Policies { - for _, id := range policy.SourcePostureChecks { - if id == postureChecksID { - return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name) - } - } + // Check if posture check is linked to any policy + if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked { + return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name) } postureChecks := account.PostureChecks[postureChecksIdx] @@ -217,3 +214,25 @@ func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks } } } + +func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) { + for _, policy := range account.Policies { + if 
slices.Contains(policy.SourcePostureChecks, postureChecksID) { + return true, policy + } + } + return false, nil +} + +// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers. +func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool { + if !exists { + return false + } + + isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID) + if !isLinked { + return false + } + return anyGroupHasPeers(account, linkedPolicy.ruleGroups()) +} diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index d837120f462..7d31956f955 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -3,7 +3,10 @@ package server import ( "context" "testing" + "time" + "github.com/netbirdio/netbird/management/server/group" + "github.com/rs/xid" "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/posture" @@ -118,3 +121,458 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { return am.Store.GetAccount(context.Background(), account.Id) } + +func TestPostureCheckAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + postureCheck := posture.Checks{ + ID: "postureCheck", + Name: "postureCheck", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.28.0", + }, + }, + } + + // Saving unused posture check should not update account peers and not send peer update + t.Run("saving unused posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating unused posture check should not update account peers and not send peer update + t.Run("updating unused posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + policy := Policy{ + ID: "policyA", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + + // Linking posture check to policy 
should trigger update account peers and send peer update + t.Run("linking posture check to policy with peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating linked posture checks should update account peers and send peer update + t.Run("updating linked to posture check with peers", func(t *testing.T) { + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + {LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"}, + }, + }, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving unchanged posture check should not trigger account peers update and not send peer update + // since there is no change in the network map + t.Run("saving unchanged posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Removing posture check from policy should trigger account peers update and send peer update + t.Run("removing posture check from policy", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policy.SourcePostureChecks = []string{} + + err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Deleting unused posture check should not trigger account peers update and not send peer update + t.Run("deleting unused posture check", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update + t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { + policy = Policy{ + ID: "policyB", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupB"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + 
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating linked posture check to policy where destination has peers but source does not + // should trigger account peers update and send peer update + t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) { + updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) + }) + policy = Policy{ + ID: "policyB", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + Enabled: true, + Sources: []string{"groupB"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg1) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating linked posture check to policy where source has peers but destination does not, + // should not trigger account peers update or send peer update + t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { + policy = Policy{ + ID: "policyB", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupB"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.29.0", + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Updating linked client posture check to policy where source has peers but destination does not, + // should trigger account peers update and send peer update + t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) { + policy = Policy{ + ID: "policyB", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: 
[]string{"groupB"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{postureCheck.ID}, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + assert.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + postureCheck.Checks = posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/bin/netbird", + }, + }, + }, + } + err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} + +func TestArePostureCheckChangesAffectingPeers(t *testing.T) { + account := &Account{ + Policies: []*Policy{ + { + ID: "policyA", + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + }, + }, + SourcePostureChecks: []string{"checkA"}, + }, + }, + Groups: map[string]*group.Group{ + "groupA": { + ID: "groupA", + Peers: []string{"peer1"}, + }, + "groupB": { + ID: "groupB", + Peers: []string{}, + }, + }, + PostureChecks: []*posture.Checks{ + { + ID: "checkA", + }, + { + ID: "checkB", + }, + }, + } + + t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.True(t, result) + }) + + t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { + result := arePostureCheckChangesAffectingPeers(account, "checkB", true) + assert.False(t, result) + }) + + t.Run("posture check does not exist", func(t *testing.T) { + result := arePostureCheckChangesAffectingPeers(account, "unknown", false) + assert.False(t, result) + }) + + t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { + account.Policies[0].Rules[0].Sources = []string{"groupB"} + account.Policies[0].Rules[0].Destinations = []string{"groupA"} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.True(t, result) + }) + + t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { + account.Policies[0].Rules[0].Sources = []string{"groupA"} + account.Policies[0].Rules[0].Destinations = []string{"groupB"} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.True(t, result) + }) + + t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { + account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} + account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.False(t, result) + }) + + t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { + account.Groups["groupA"].Peers = []string{} + result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + assert.False(t, result) + }) +} diff --git a/management/server/route.go b/management/server/route.go index 39ee6170c77..1cf00b37c46 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -237,7 +237,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } - am.updateAccountPeers(ctx, account) + if isRouteChangeAffectPeers(account, &newRoute) { + 
am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -313,6 +315,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } + oldRoute := account.Routes[routeToSave.ID] account.Routes[routeToSave.ID] = routeToSave account.Network.IncSerial() @@ -320,7 +323,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - am.updateAccountPeers(ctx, account) + if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -350,7 +355,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - am.updateAccountPeers(ctx, account) + if isRouteChangeAffectPeers(account, routy) { + am.updateAccountPeers(ctx, account) + } return nil } @@ -641,3 +648,9 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { } return &portInfo } + +// isRouteChangeAffectPeers checks if a given route affects peers by determining +// if it has a routing peer, distribution, or peer groups that include peers +func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { + return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +} diff --git a/management/server/route_test.go b/management/server/route_test.go index 09cbe53ff53..a4b320c7ee2 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -6,6 +6,7 @@ import ( "net" "net/netip" "testing" + "time" "github.com/rs/xid" "github.com/stretchr/testify/assert" @@ -1777,3 +1778,281 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }) } + +func TestRouteAccountPeersUpdate(t *testing.T) { + manager, err := createRouterManager(t) + require.NoError(t, err, "failed to create account manager") + + account, err := initTestRouteAccount(t, manager) + require.NoError(t, err, "failed to init testing account") + + err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + { + ID: "groupA", + Name: "GroupA", + Peers: []string{}, + }, + { + ID: "groupB", + Name: "GroupB", + Peers: []string{}, + }, + { + ID: "groupC", + Name: "GroupC", + Peers: []string{}, + }, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID) + }) + + // Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update + t.Run("creating route no routing peer and no peers in groups", func(t *testing.T) { + route := route.Route{ + ID: "testingRoute1", + Network: netip.MustParsePrefix("100.65.250.202/32"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupA"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupA"}, + } + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err := manager.CreateRoute( + context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + route.PeerGroups, route.Description, 
route.NetID, route.Masquerade, route.Metric, + route.Groups, []string{}, true, userID, route.KeepRoute, + ) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + + }) + + // Creating a route with no routing peer and having peers in groups should update account peers and send peer update + t.Run("creating a route with peers in PeerGroups and Groups", func(t *testing.T) { + route := route.Route{ + ID: "testingRoute2", + Network: netip.MustParsePrefix("192.0.2.0/32"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{routeGroup3}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup3}, + } + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err := manager.CreateRoute( + context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, + route.Groups, []string{}, true, userID, route.KeepRoute, + ) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + + }) + + baseRoute := route.Route{ + ID: "testingRoute3", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + } + + // Creating route should update account peers and send peer update + t.Run("creating route with a routing peer", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + newRoute, err := manager.CreateRoute( + context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, + baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, + baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, + ) + require.NoError(t, err) + baseRoute = *newRoute + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating the route should update account peers and send peer update when there is peers in group + t.Run("updating route", func(t *testing.T) { + baseRoute.Groups = []string{routeGroup1, routeGroup2} + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Updating unchanged route should update account peers and not send peer update + t.Run("updating unchanged route", func(t *testing.T) { + baseRoute.Groups = []string{routeGroup1, routeGroup2} + + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting the route should update account peers and send peer update + t.Run("deleting route", func(t 
*testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to route peer groups that do not have any peers should update account peers and send peer update + t.Run("adding peer to route peer groups that do not have any peers", func(t *testing.T) { + newRoute := route.Route{ + Network: netip.MustParsePrefix("192.168.12.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupB"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + } + _, err := manager.CreateRoute( + context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupB", + Name: "GroupB", + Peers: []string{peer1ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Adding peer to route groups that do not have any peers should update account peers and send peer update + t.Run("adding peer to route groups that do not have any peers", func(t *testing.T) { + newRoute := route.Route{ + Network: netip.MustParsePrefix("192.168.13.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{"groupB"}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"groupC"}, + } + _, err := manager.CreateRoute( + context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, + newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, + ) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupC", + Name: "GroupC", + Peers: []string{peer1ID}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 9521e22d339..e84f8fcd687 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -323,8 +323,6 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str } }() - am.updateAccountPeers(ctx, account) - return newKey, nil } diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index aa5075b024e..651b5401047 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" 
"github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" @@ -352,3 +353,73 @@ func TestSetupKey_Copy(t *testing.T) { key.UpdatedAt, key.AutoGroups) } + +func TestSetupKeyAccountPeersUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + policy := Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"group"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + require.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + var setupKey *SetupKey + + // Creating setup key should not update account peers and not send peer update + t.Run("creating setup key", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Saving setup key should not update account peers and not send peer update + t.Run("saving setup key", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveSetupKey(context.Background(), account.Id, setupKey, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) +} diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 32a59128bf1..168973cad91 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -26,8 +26,11 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0); INSERT INTO users 
VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); INSERT INTO installations VALUES(1,''); +INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); +INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 0188cef52a9..6fb96c97124 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -2,9 +2,13 @@ package server import ( "context" + "fmt" + "runtime/debug" "sync" "time" + "github.com/netbirdio/netbird/management/server/differs" + "github.com/r3labs/diff/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" @@ -14,14 +18,17 @@ import ( const channelBufferSize = 100 type UpdateMessage struct { - Update *proto.SyncResponse + Update *proto.SyncResponse + NetworkMap *NetworkMap } type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID peerChannels map[string]chan *UpdateMessage + // peerUpdateMessage is the last UpdateMessage sent to a peer, indexed by Peer.ID. 
+ peerUpdateMessage map[string]*UpdateMessage // channelsMux keeps the mutex to access peerChannels - channelsMux *sync.Mutex + channelsMux *sync.RWMutex // metrics provides method to collect application metrics metrics telemetry.AppMetrics } @@ -29,9 +36,10 @@ type PeersUpdateManager struct { // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), - channelsMux: &sync.Mutex{}, - metrics: metrics, + peerChannels: make(map[string]chan *UpdateMessage), + peerUpdateMessage: make(map[string]*UpdateMessage), + channelsMux: &sync.RWMutex{}, + metrics: metrics, } } @@ -40,7 +48,17 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda start := time.Now() var found, dropped bool + // skip sending sync update to the peer if there is no change in update message, + // it will not check on turn credential refresh as we do not send network map or client posture checks + if update.NetworkMap != nil { + updated := p.handlePeerMessageUpdate(ctx, peerID, update) + if !updated { + return + } + } + p.channelsMux.Lock() + defer func() { p.channelsMux.Unlock() if p.metrics != nil { @@ -48,6 +66,16 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } }() + if update.NetworkMap != nil { + lastSentUpdate := p.peerUpdateMessage[peerID] + if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() { + log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update", + peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial()) + return + } + p.peerUpdateMessage[peerID] = update + } + if channel, ok := p.peerChannels[peerID]; ok { found = true select { @@ -80,6 +108,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c closed = true delete(p.peerChannels, peerID) close(channel) + delete(p.peerUpdateMessage, peerID) } // mbragin: todo shouldn't it be more? or configurable? channel := make(chan *UpdateMessage, channelBufferSize) @@ -94,6 +123,7 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) + delete(p.peerUpdateMessage, peerID) } log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) @@ -170,3 +200,72 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } + +// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent. +func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool { + p.channelsMux.RLock() + lastSentUpdate := p.peerUpdateMessage[peerID] + p.channelsMux.RUnlock() + + if lastSentUpdate != nil { + updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update) + if err != nil { + log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err) + return false + } + if !updated { + log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID) + return false + } + } + + return true +} + +// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent. 
+func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) { + defer func() { + if r := recover(); r != nil { + log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack()) + isNew, err = true, nil + } + }() + + if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() { + return false, nil + } + + differ, err := diff.NewDiffer( + diff.CustomValueDiffers(&differs.NetIPAddr{}), + diff.CustomValueDiffers(&differs.NetIPPrefix{}), + ) + if err != nil { + return false, fmt.Errorf("failed to create differ: %v", err) + } + + lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks) + currFiles := getChecksFiles(currUpdateToSend.Update.Checks) + + changelog, err := differ.Diff(lastSentFiles, currFiles) + if err != nil { + return false, fmt.Errorf("failed to diff checks: %v", err) + } + if len(changelog) > 0 { + return true, nil + } + + changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap) + if err != nil { + return false, fmt.Errorf("failed to diff network map: %v", err) + } + return len(changelog) > 0, nil +} + +// getChecksFiles returns a list of files from the given checks. +func getChecksFiles(checks []*proto.Checks) []string { + files := make([]string, 0, len(checks)) + for _, check := range checks { + files = append(files, check.GetFiles()...) + } + return files +} diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 69f5b895c45..52b715e9503 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -2,10 +2,19 @@ package server import ( "context" + "net" + "net/netip" "testing" "time" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + nbroute "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/util" + "github.com/stretchr/testify/assert" ) // var peersUpdater *PeersUpdateManager @@ -77,3 +86,470 @@ func TestCloseChannel(t *testing.T) { t.Error("Error closing the channel") } } + +func TestHandlePeerMessageUpdate(t *testing.T) { + tests := []struct { + name string + peerID string + existingUpdate *UpdateMessage + newUpdate *UpdateMessage + expectedResult bool + }{ + { + name: "update message with turn credentials update", + peerID: "peer", + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + WiretrusteeConfig: &proto.WiretrusteeConfig{}, + }, + }, + expectedResult: true, + }, + { + name: "update message for peer without existing update", + peerID: "peer1", + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + expectedResult: true, + }, + { + name: "update message with no changes in update", + peerID: "peer2", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + expectedResult: false, + }, + { + name: "update message with changes in checks", + peerID: "peer3", + existingUpdate: &UpdateMessage{ + 
Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 2}, + Checks: []*proto.Checks{ + { + Files: []string{"/usr/bin/netbird"}, + }, + }, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + expectedResult: true, + }, + { + name: "update message with lower serial number", + peerID: "peer4", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 2}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewPeersUpdateManager(nil) + ctx := context.Background() + + if tt.existingUpdate != nil { + p.peerUpdateMessage[tt.peerID] = tt.existingUpdate + } + + result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate) + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestIsNewPeerUpdateMessage(t *testing.T) { + t.Run("Unchanged value", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.False(t, message) + }) + + t.Run("Unchanged value with serial incremented", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.False(t, message) + }) + + t.Run("Updating routes network", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32") + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + + }) + + t.Run("Updating routes groups", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"} + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating network map peers", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newPeer := &nbpeer.Peer{ + IP: net.ParseIP("192.168.1.4"), + SSHEnabled: true, + Key: "peer4-key", + DNSLabel: "peer4", + SSHKey: "peer4-ssh-key", + } + newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating process check", func(t *testing.T) { + newUpdateMessage1 
:= createMockUpdateMessage(t) + + newUpdateMessage2 := createMockUpdateMessage(t) + newUpdateMessage2.Update.NetworkMap.Serial++ + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.False(t, message) + + newUpdateMessage3 := createMockUpdateMessage(t) + newUpdateMessage3.Update.Checks = []*proto.Checks{} + newUpdateMessage3.Update.NetworkMap.Serial++ + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3) + assert.NoError(t, err) + assert.True(t, message) + + newUpdateMessage4 := createMockUpdateMessage(t) + check := &posture.Checks{ + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/local/netbird", + MacPath: "/usr/bin/netbird", + }, + }, + }, + }, + } + newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)} + newUpdateMessage4.Update.NetworkMap.Serial++ + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4) + assert.NoError(t, err) + assert.True(t, message) + + newUpdateMessage5 := createMockUpdateMessage(t) + check = &posture.Checks{ + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/bin/netbird", + WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", + MacPath: "/usr/local/netbird", + }, + }, + }, + }, + } + newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)} + newUpdateMessage5.Update.NetworkMap.Serial++ + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating DNS configuration", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newDomain := "newexample.com" + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append( + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains, + newDomain, + ) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating peer IP", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10") + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating firewall rule", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443" + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Add new firewall rule", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newRule := &FirewallRule{ + PeerIP: "192.168.1.3", + Direction: firewallRuleDirectionOUT, + Action: string(PolicyTrafficActionDrop), + Protocol: string(PolicyRuleProtocolUDP), + Port: "53", + } + 
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Removing nameserver", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0) + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating name server IP", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4") + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + + t.Run("Updating custom DNS zone", func(t *testing.T) { + newUpdateMessage1 := createMockUpdateMessage(t) + newUpdateMessage2 := createMockUpdateMessage(t) + + newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2" + newUpdateMessage2.Update.NetworkMap.Serial++ + + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + assert.NoError(t, err) + assert.True(t, message) + }) + +} + +func createMockUpdateMessage(t *testing.T) *UpdateMessage { + t.Helper() + + _, ipNet, err := net.ParseCIDR("192.168.1.0/24") + if err != nil { + t.Fatal(err) + } + domainList, err := domain.FromStringList([]string{"example.com"}) + if err != nil { + t.Fatal(err) + } + + config := &Config{ + Signal: &Host{ + Proto: "https", + URI: "signal.uri", + Username: "", + Password: "", + }, + Stuns: []*Host{{URI: "stun.uri", Proto: UDP}}, + TURNConfig: &TURNConfig{ + Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}}, + }, + } + peer := &nbpeer.Peer{ + IP: net.ParseIP("192.168.1.1"), + SSHEnabled: true, + Key: "peer-key", + DNSLabel: "peer1", + SSHKey: "peer1-ssh-key", + } + + secretManager := NewTimeBasedAuthSecretsManager( + NewPeersUpdateManager(nil), + &TURNConfig{ + TimeBasedCredentials: false, + CredentialsTTL: util.Duration{ + Duration: defaultDuration, + }, + Secret: "secret", + Turns: []*Host{TurnTestHost}, + }, + &Relay{ + Addresses: []string{"localhost:0"}, + CredentialsTTL: util.Duration{Duration: time.Hour}, + Secret: "secret", + }, + ) + + networkMap := &NetworkMap{ + Network: &Network{Net: *ipNet, Serial: 1000}, + Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, + OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, + Routes: []*nbroute.Route{ + { + ID: "route1", + Network: netip.MustParsePrefix("10.0.0.0/24"), + KeepRoute: true, + NetID: "route1", + Peer: "peer1", + NetworkType: 1, + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{"test1", "test2"}, + }, + { + ID: "route2", + Domains: domainList, + KeepRoute: true, + NetID: "route2", + Peer: "peer1", + 
NetworkType: 1, + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{"test1", "test2"}, + }, + }, + DNSConfig: nbdns.Config{ + ServiceEnable: true, + NameServerGroups: []*nbdns.NameServerGroup{ + { + NameServers: []nbdns.NameServer{{ + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + Primary: true, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: true, + }, + { + ID: "ns1", + NameServers: []nbdns.NameServer{{ + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, + }}, + Groups: []string{"group1"}, + Primary: true, + Domains: []string{"example.com"}, + Enabled: true, + SearchDomainsEnabled: true, + }, + }, + CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, + }, + FirewallRules: []*FirewallRule{ + {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, + }, + } + dnsName := "example.com" + checks := []*posture.Checks{ + { + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + { + LinuxPath: "/usr/bin/netbird", + WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", + MacPath: "/usr/bin/netbird", + }, + }, + }, + }, + }, + } + dnsCache := &DNSConfigCache{} + + turnToken, err := secretManager.GenerateTurnToken() + if err != nil { + t.Fatal(err) + } + + relayToken, err := secretManager.GenerateRelayToken() + if err != nil { + t.Fatal(err) + } + + return &UpdateMessage{ + Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache), + NetworkMap: networkMap, + } +} diff --git a/management/server/user.go b/management/server/user.go index 71608ef20e1..9fdd3a6eeea 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "time" @@ -473,7 +474,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init } func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { - meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) + meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) if err != nil { return err } @@ -485,15 +486,22 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account } am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } return nil } -func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error { +func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) { peers, err := account.FindUserPeers(targetUserID) if err != nil { - return status.Errorf(status.Internal, "failed to find user peers") + return false, status.Errorf(status.Internal, "failed to find user peers") + } + + hadPeers := len(peers) > 0 + if !hadPeers { + return false, nil } peerIDs := make([]string, 0, len(peers)) @@ -501,7 +509,7 @@ func (am *DefaultAccountManager) 
deleteUserPeers(ctx context.Context, initiatorU peerIDs = append(peerIDs, peer.ID) } - return am.deletePeers(ctx, account, peerIDs, initiatorUserID) + return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID) } // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. @@ -745,6 +753,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, updatedUsers := make([]*UserInfo, 0, len(updates)) var ( expiredPeers []*nbpeer.Peer + userIDs []string eventsToStore []func() ) @@ -753,6 +762,8 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } + userIDs = append(userIDs, update.Id) + oldUser := account.Users[update.Id] if oldUser == nil { if !addIfNotExists { @@ -816,7 +827,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, err } - if account.Settings.GroupsPropagationEnabled { + if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { am.updateAccountPeers(ctx, account) } @@ -1167,7 +1178,10 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return status.Errorf(status.PermissionDenied, "only users with admin power can delete users") } - var allErrors error + var ( + allErrors error + updateAccountPeers bool + ) deletedUsersMeta := make(map[string]map[string]any) for _, targetUserID := range targetUserIDs { @@ -1193,12 +1207,16 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) + meta, hadPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) if err != nil { allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err)) continue } + if hadPeers { + updateAccountPeers = true + } + delete(account.Users, targetUserID) deletedUsersMeta[targetUserID] = meta } @@ -1208,7 +1226,9 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return fmt.Errorf("failed to delete users: %w", err) } - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } for targetUserID, meta := range deletedUsersMeta { am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) @@ -1217,11 +1237,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return allErrors } -func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) { +func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) { tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) if err != nil { log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) - return nil, err + return nil, false, err } if !isNil(am.idpManager) { @@ -1232,16 +1252,16 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun err = am.deleteUserFromIDP(ctx, targetUserID, account.Id) if err != nil { log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID) - return nil, err + return nil, false, err } } else { log.WithContext(ctx).Debugf("skipped deleting 
user %s from IDP, error: %v", targetUserID, err) } } - err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) + hadPeers, err := am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) if err != nil { - return nil, err + return nil, false, err } u, err := account.FindUser(targetUserID) @@ -1254,7 +1274,7 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun tuCreatedAt = u.CreatedAt } - return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil + return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil } // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. @@ -1333,3 +1353,13 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa } return nil, false } + +// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account. +func areUsersLinkedToPeers(account *Account, userIDs []string) bool { + for _, peer := range account.Peers { + if slices.Contains(userIDs, peer.UserID) { + return true + } + } + return false +} diff --git a/management/server/user_test.go b/management/server/user_test.go index 1a5704551bc..d4f560a54c7 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,9 +10,12 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" gocache "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" @@ -1264,3 +1267,165 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { }) } } + +func TestUserAccountPeersUpdate(t *testing.T) { + // account groups propagation is enabled + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + require.NoError(t, err) + + policy := Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + require.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Creating a new regular user should not update account peers and not send peer update + t.Run("creating new regular user with no groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser1", + AccountID: account.Id, + Role: UserRoleUser, + Issued: UserIssuedAPI, + }, true) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // updating user with no linked peers should not 
update account peers and not send peer update + t.Run("updating user with no linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser1", + AccountID: account.Id, + Role: UserRoleUser, + Issued: UserIssuedAPI, + }, false) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // deleting user with no linked peers should not update account peers and not send peer update + t.Run("deleting user with no linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser1") + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // create a user and add new peer with the user + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser2", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + }, true) + require.NoError(t, err) + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + // updating user with linked peers should update account peers and send peer update + t.Run("updating user with linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser2", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + }, false) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID) + }) + + // deleting user with linked peers should update account peers and send peer update + t.Run("deleting user with linked peers", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, peer4UpdMsg) + close(done) + }() + + err = manager.DeleteUser(context.Background(), account.Id, userID, "regularUser2") + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) +} From 563dca705cc30250866438aae33942f4efcd3e17 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 23 Oct 2024 16:40:15 +0200 Subject: [PATCH 56/81] [management] Fix session inactivity response (#2770) --- management/server/http/accounts_handler.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index 4d4066de487..4baf9c6925f 100644 --- a/management/server/http/accounts_handler.go +++ 
b/management/server/http/accounts_handler.go @@ -137,13 +137,15 @@ func toAccountResponse(accountID string, settings *server.Settings) *api.Account } apiSettings := api.AccountSettings{ - PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()), - PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled, - GroupsPropagationEnabled: &settings.GroupsPropagationEnabled, - JwtGroupsEnabled: &settings.JWTGroupsEnabled, - JwtGroupsClaimName: &settings.JWTGroupsClaimName, - JwtAllowGroups: &jwtAllowGroups, - RegularUsersViewBlocked: settings.RegularUsersViewBlocked, + PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()), + PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled, + PeerInactivityExpiration: int(settings.PeerInactivityExpiration.Seconds()), + PeerInactivityExpirationEnabled: settings.PeerInactivityExpirationEnabled, + GroupsPropagationEnabled: &settings.GroupsPropagationEnabled, + JwtGroupsEnabled: &settings.JWTGroupsEnabled, + JwtGroupsClaimName: &settings.JWTGroupsClaimName, + JwtAllowGroups: &jwtAllowGroups, + RegularUsersViewBlocked: settings.RegularUsersViewBlocked, } if settings.Extra != nil { From 44f2ce666ec9debf8c88daa1e5c66235efcbf063 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 23 Oct 2024 18:32:27 +0200 Subject: [PATCH 57/81] [relay-client] Log exposed address (#2771) * Log exposed address --- relay/client/client.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/relay/client/client.go b/relay/client/client.go index 90bc3ac418f..20a73f4b343 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -184,12 +184,14 @@ func (c *Client) Connect() error { return err } + c.log = c.log.WithField("relay", c.instanceURL.String()) + c.log.Infof("relay connection established") + c.serviceIsRunning = true c.wgReadLoop.Add(1) go c.readLoop(c.relayConn) - c.log.Infof("relay connection established") return nil } From 869537c9511edbaf9f22bbcbffee5a2bcbbc3b93 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:53:46 +0200 Subject: [PATCH 58/81] [client] Cleanup dns and route states on startup (#2757) --- client/firewall/nftables/state.go | 1 + client/internal/connect.go | 12 +- client/internal/dns/consts_freebsd.go | 3 +- client/internal/dns/consts_linux.go | 3 +- client/internal/dns/file_repair_unix.go | 8 +- client/internal/dns/file_repair_unix_test.go | 9 +- client/internal/dns/file_unix.go | 24 +- client/internal/dns/host.go | 19 +- client/internal/dns/host_android.go | 12 +- client/internal/dns/host_darwin.go | 18 +- client/internal/dns/host_ios.go | 11 +- client/internal/dns/host_unix.go | 28 +- client/internal/dns/host_windows.go | 20 +- client/internal/dns/network_manager_unix.go | 21 +- client/internal/dns/resolvconf_unix.go | 23 +- client/internal/dns/server.go | 40 ++- client/internal/dns/server_test.go | 10 +- client/internal/dns/server_windows.go | 2 +- client/internal/dns/systemd_freebsd.go | 2 +- client/internal/dns/systemd_linux.go | 21 +- .../internal/dns/unclean_shutdown_android.go | 5 - .../internal/dns/unclean_shutdown_darwin.go | 48 +-- client/internal/dns/unclean_shutdown_ios.go | 5 - .../internal/dns/unclean_shutdown_mobile.go | 14 + client/internal/dns/unclean_shutdown_unix.go | 81 ++--- .../internal/dns/unclean_shutdown_windows.go | 69 +--- client/internal/engine.go | 35 +- client/internal/routemanager/manager.go | 15 +- client/internal/routemanager/manager_test.go | 4 +- client/internal/routemanager/mock.go | 9 +- 
.../internal/routemanager/systemops/state.go | 81 +++++ .../systemops/systemops_android.go | 9 +- .../systemops/systemops_generic.go | 51 ++- .../systemops/systemops_generic_test.go | 8 +- .../routemanager/systemops/systemops_ios.go | 25 +- .../routemanager/systemops/systemops_linux.go | 13 +- .../routemanager/systemops/systemops_unix.go | 9 +- .../systemops/systemops_windows.go | 9 +- client/internal/statemanager/manager.go | 298 ++++++++++++++++++ client/internal/statemanager/path.go | 35 ++ client/ios/NetBirdSDK/client.go | 4 +- client/server/server.go | 47 +++ 42 files changed, 785 insertions(+), 376 deletions(-) create mode 100644 client/firewall/nftables/state.go delete mode 100644 client/internal/dns/unclean_shutdown_android.go delete mode 100644 client/internal/dns/unclean_shutdown_ios.go create mode 100644 client/internal/dns/unclean_shutdown_mobile.go create mode 100644 client/internal/routemanager/systemops/state.go create mode 100644 client/internal/statemanager/manager.go create mode 100644 client/internal/statemanager/path.go diff --git a/client/firewall/nftables/state.go b/client/firewall/nftables/state.go new file mode 100644 index 00000000000..7027fe98719 --- /dev/null +++ b/client/firewall/nftables/state.go @@ -0,0 +1 @@ +package nftables diff --git a/client/internal/connect.go b/client/internal/connect.go index 74dc1f1b56d..13f10fbf1e6 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -117,12 +117,6 @@ func (c *ConnectClient) run( log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) - // Check if client was not shut down in a clean way and restore DNS config if required. - // Otherwise, we might not be able to connect to the management server to retrieve new config. 
- if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil { - log.Errorf("checking unclean shutdown error: %s", err) - } - backOff := &backoff.ExponentialBackOff{ InitialInterval: time.Second, RandomizationFactor: 1, @@ -358,7 +352,11 @@ func (c *ConnectClient) Stop() error { if c.engine == nil { return nil } - return c.engine.Stop() + if err := c.engine.Stop(); err != nil { + return fmt.Errorf("stop engine: %w", err) + } + + return nil } func (c *ConnectClient) isContextCancelled() bool { diff --git a/client/internal/dns/consts_freebsd.go b/client/internal/dns/consts_freebsd.go index 958eca8e55b..64c8fe5ebed 100644 --- a/client/internal/dns/consts_freebsd.go +++ b/client/internal/dns/consts_freebsd.go @@ -1,6 +1,5 @@ package dns const ( - fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" - fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager" + fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" ) diff --git a/client/internal/dns/consts_linux.go b/client/internal/dns/consts_linux.go index 32456a50fee..15614b0c599 100644 --- a/client/internal/dns/consts_linux.go +++ b/client/internal/dns/consts_linux.go @@ -3,6 +3,5 @@ package dns const ( - fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" - fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager" + fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" ) diff --git a/client/internal/dns/file_repair_unix.go b/client/internal/dns/file_repair_unix.go index ae2c33b8684..9a9218fa1f0 100644 --- a/client/internal/dns/file_repair_unix.go +++ b/client/internal/dns/file_repair_unix.go @@ -9,6 +9,8 @@ import ( "github.com/fsnotify/fsnotify" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) var ( @@ -20,7 +22,7 @@ var ( } ) -type repairConfFn func([]string, string, *resolvConf) error +type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error type repair struct { operationFile string @@ -40,7 +42,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair { } } -func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) { +func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) { if f.inotify != nil { return } @@ -81,7 +83,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin log.Errorf("failed to rm inotify watch for resolv.conf: %s", err) } - err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf) + err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager) if err != nil { log.Errorf("failed to repair resolv.conf: %v", err) } diff --git a/client/internal/dns/file_repair_unix_test.go b/client/internal/dns/file_repair_unix_test.go index 4dba79e996d..e948557b661 100644 --- a/client/internal/dns/file_repair_unix_test.go +++ b/client/internal/dns/file_repair_unix_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/util" ) @@ -104,14 +105,14 @@ nameserver 8.8.8.8`, var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf) error { + updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { changed = true cancel() return nil } r := newRepair(operationFile, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") + 
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755) if err != nil { @@ -151,14 +152,14 @@ searchdomain netbird.cloud something` var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf) error { + updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { changed = true cancel() return nil } r := newRepair(tmpLink, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") + r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) err = os.WriteFile(tmpLink, []byte(modifyContent), 0755) if err != nil { diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go index 624e089cb48..02ae26e10e3 100644 --- a/client/internal/dns/file_unix.go +++ b/client/internal/dns/file_unix.go @@ -11,6 +11,8 @@ import ( "time" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -36,7 +38,7 @@ type fileConfigurator struct { nbNameserverIP string } -func newFileConfigurator() (hostManager, error) { +func newFileConfigurator() (*fileConfigurator, error) { fc := &fileConfigurator{} fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig) return fc, nil @@ -46,7 +48,7 @@ func (f *fileConfigurator) supportCustomPort() bool { return false } -func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { backupFileExist := f.isBackupFileExist() if !config.RouteAll { if backupFileExist { @@ -76,15 +78,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { f.repair.stopWatchFileChanges() - err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf) + err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager) if err != nil { return err } - f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP) + f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager) return nil } -func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error { +func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error { searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) nameServers := generateNsList(nbNameserverIP, cfg) @@ -107,7 +109,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. 
Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList) // create another backup for unclean shutdown detection right after overwriting the original resolv.conf - if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil { + if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil { log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) } @@ -145,10 +147,6 @@ func (f *fileConfigurator) restore() error { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return os.RemoveAll(fileDefaultResolvConfBackupLocation) } @@ -176,7 +174,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add return restoreResolvConfFile() } - log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring") + log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress) return nil } @@ -192,10 +190,6 @@ func restoreResolvConfFile() error { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err) - } - return nil } diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index e55a0705556..e2b5f699a7d 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -5,14 +5,14 @@ import ( "net/netip" "strings" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) type hostManager interface { - applyDNSConfig(config HostDNSConfig) error + applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNS() error supportCustomPort() bool - restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error } type SystemDNSSettings struct { @@ -35,15 +35,15 @@ type DomainConfig struct { } type mockHostConfigurator struct { - applyDNSConfigFunc func(config HostDNSConfig) error + applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNSFunc func() error supportCustomPortFunc func() bool restoreUncleanShutdownDNSFunc func(*netip.Addr) error } -func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { if m.applyDNSConfigFunc != nil { - return m.applyDNSConfigFunc(config) + return m.applyDNSConfigFunc(config, stateManager) } return fmt.Errorf("method applyDNSSettings is not implemented") } @@ -62,16 +62,9 @@ func (m *mockHostConfigurator) supportCustomPort() bool { return false } -func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error { - if m.restoreUncleanShutdownDNSFunc != nil { - return m.restoreUncleanShutdownDNSFunc(storedDNSAddress) - } - return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented") -} - func newNoopHostMocker() hostManager { return &mockHostConfigurator{ - applyDNSConfigFunc: func(config HostDNSConfig) error { return 
nil }, + applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil }, restoreHostDNSFunc: func() error { return nil }, supportCustomPortFunc: func() bool { return true }, restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil }, diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go index 9230cb257f4..5653710d705 100644 --- a/client/internal/dns/host_android.go +++ b/client/internal/dns/host_android.go @@ -1,15 +1,17 @@ package dns -import "net/netip" +import ( + "github.com/netbirdio/netbird/client/internal/statemanager" +) type androidHostManager struct { } -func newHostManager() (hostManager, error) { +func newHostManager() (*androidHostManager, error) { return &androidHostManager{}, nil } -func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error { +func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { return nil } @@ -20,7 +22,3 @@ func (a androidHostManager) restoreHostDNS() error { func (a androidHostManager) supportCustomPort() bool { return false } - -func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error { - return nil -} diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 5dee305c2ed..b8ba33e342c 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -8,12 +8,13 @@ import ( "fmt" "io" "net" - "net/netip" "os/exec" "strconv" "strings" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -37,7 +38,7 @@ type systemConfigurator struct { systemDNSSettings SystemDNSSettings } -func newHostManager() (hostManager, error) { +func newHostManager() (*systemConfigurator, error) { return &systemConfigurator{ createdKeys: make(map[string]struct{}), }, nil @@ -47,12 +48,11 @@ func (s *systemConfigurator) supportCustomPort() bool { return true } -func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { var err error - // create a file for unclean shutdown detection - if err := createUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to create unclean shutdown file: %s", err) + if err := stateManager.UpdateState(&ShutdownState{}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } var ( @@ -123,10 +123,6 @@ func (s *systemConfigurator) restoreHostDNS() error { } } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown file: %s", err) - } - return nil } @@ -320,7 +316,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) { return primaryService, router, nil } -func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { +func (s *systemConfigurator) restoreUncleanShutdownDNS() error { if err := s.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via scutil: %w", err) } diff --git a/client/internal/dns/host_ios.go b/client/internal/dns/host_ios.go index ad8b14fb8d6..4a0acf57241 100644 --- a/client/internal/dns/host_ios.go +++ b/client/internal/dns/host_ios.go @@ -3,9 +3,10 @@ package dns import ( "encoding/json" "fmt" - "net/netip" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) type iosHostManager struct { @@ -13,13 +14,13 @@ type iosHostManager struct { config HostDNSConfig } -func 
newHostManager(dnsManager IosDnsManager) (hostManager, error) { +func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) { return &iosHostManager{ dnsManager: dnsManager, }, nil } -func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error { +func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error { jsonData, err := json.Marshal(config) if err != nil { return fmt.Errorf("marshal: %w", err) @@ -37,7 +38,3 @@ func (a iosHostManager) restoreHostDNS() error { func (a iosHostManager) supportCustomPort() bool { return false } - -func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error { - return nil -} diff --git a/client/internal/dns/host_unix.go b/client/internal/dns/host_unix.go index 72b8f6c6e6b..7bd4aec6482 100644 --- a/client/internal/dns/host_unix.go +++ b/client/internal/dns/host_unix.go @@ -4,9 +4,9 @@ package dns import ( "bufio" - "errors" "fmt" "io" + "net/netip" "os" "strings" @@ -21,27 +21,8 @@ const ( resolvConfManager ) -var ErrUnknownOsManagerType = errors.New("unknown os manager type") - type osManagerType int -func newOsManagerType(osManager string) (osManagerType, error) { - switch osManager { - case "netbird": - return fileManager, nil - case "file": - return netbirdManager, nil - case "networkManager": - return networkManager, nil - case "systemd": - return systemdManager, nil - case "resolvconf": - return resolvConfManager, nil - default: - return 0, ErrUnknownOsManagerType - } -} - func (t osManagerType) String() string { switch t { case netbirdManager: @@ -59,6 +40,11 @@ func (t osManagerType) String() string { } } +type restoreHostManager interface { + hostManager + restoreUncleanShutdownDNS(*netip.Addr) error +} + func newHostManager(wgInterface string) (hostManager, error) { osManager, err := getOSDNSManagerType() if err != nil { @@ -69,7 +55,7 @@ func newHostManager(wgInterface string) (hostManager, error) { return newHostManagerFromType(wgInterface, osManager) } -func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) { +func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) { switch osManager { case networkManager: return newNetworkManagerDbusConfigurator(wgInterface) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index c8bf2e55237..7ecca8a41f4 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -3,11 +3,12 @@ package dns import ( "fmt" "io" - "net/netip" "strings" log "github.com/sirupsen/logrus" "golang.org/x/sys/windows/registry" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -31,7 +32,7 @@ type registryConfigurator struct { routingAll bool } -func newHostManager(wgInterface WGIface) (hostManager, error) { +func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { guid, err := wgInterface.GetInterfaceGUIDString() if err != nil { return nil, err @@ -39,7 +40,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) { return newHostManagerWithGuid(guid) } -func newHostManagerWithGuid(guid string) (hostManager, error) { +func newHostManagerWithGuid(guid string) (*registryConfigurator, error) { return ®istryConfigurator{ guid: guid, }, nil @@ -49,7 +50,7 @@ func (r *registryConfigurator) supportCustomPort() bool { return false } -func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, 
stateManager *statemanager.Manager) error { var err error if config.RouteAll { err = r.addDNSSetupForAll(config.ServerIP) @@ -65,9 +66,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error { log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - // create a file for unclean shutdown detection - if err := createUncleanShutdownIndicator(r.guid); err != nil { - log.Errorf("failed to create unclean shutdown file: %s", err) + if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } var ( @@ -160,10 +160,6 @@ func (r *registryConfigurator) restoreHostDNS() error { return fmt.Errorf("remove interface registry key: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown file: %s", err) - } - return nil } @@ -221,7 +217,7 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) { return regKey, nil } -func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { +func (r *registryConfigurator) restoreUncleanShutdownDNS() error { if err := r.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via registry: %w", err) } diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index 184047a643d..63bbead7728 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -16,6 +16,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbversion "github.com/netbirdio/netbird/version" ) @@ -53,6 +54,7 @@ var supportedNetworkManagerVersionConstraints = []string{ type networkManagerDbusConfigurator struct { dbusLinkObject dbus.ObjectPath routingAll bool + ifaceName string } // the types below are based on dbus specification, each field is mapped to a dbus type @@ -77,7 +79,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() { } } -func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) { +func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusConfigurator, error) { obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) if err != nil { return nil, fmt.Errorf("get nm dbus: %w", err) @@ -93,6 +95,7 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) return &networkManagerDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), + ifaceName: wgInterface, }, nil } @@ -100,7 +103,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool { return false } -func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { connSettings, configVersion, err := n.getAppliedConnectionSettings() if err != nil { return fmt.Errorf("retrieving the applied connection settings, error: %w", err) @@ -151,10 +154,12 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) - // create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. 
- // The file content itself is not important for network-manager restoration - if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: networkManager, + WgIface: n.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) @@ -171,10 +176,6 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error { return fmt.Errorf("delete connection settings: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return nil } diff --git a/client/internal/dns/resolvconf_unix.go b/client/internal/dns/resolvconf_unix.go index 0c17626c7a9..a5d1cc8a225 100644 --- a/client/internal/dns/resolvconf_unix.go +++ b/client/internal/dns/resolvconf_unix.go @@ -9,6 +9,8 @@ import ( "os/exec" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const resolvconfCommand = "resolvconf" @@ -22,7 +24,7 @@ type resolvconf struct { } // supported "openresolv" only -func newResolvConfConfigurator(wgInterface string) (hostManager, error) { +func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) { resolvConfEntries, err := parseDefaultResolvConf() if err != nil { log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err) @@ -40,7 +42,7 @@ func (r *resolvconf) supportCustomPort() bool { return false } -func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { +func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { var err error if !config.RouteAll { err = r.restoreHostDNS() @@ -60,9 +62,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { append([]string{config.ServerIP}, r.originalNameServers...), options) - // create a backup for unclean shutdown detection before the resolv.conf is changed - if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: resolvConfManager, + WgIface: r.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } err = r.applyConfig(buf) @@ -79,11 +84,7 @@ func (r *resolvconf) restoreHostDNS() error { cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName) _, err := cmd.Output() if err != nil { - return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err) - } - - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) + return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err) } return nil @@ -95,7 +96,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error { cmd.Stdin = &content _, err := cmd.Output() if err != nil { - return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err) + return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err) 
} return nil } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index a4651ebb5b0..772797fac0a 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -7,6 +7,7 @@ import ( "runtime" "strings" "sync" + "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -14,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) @@ -63,6 +65,7 @@ type DefaultServer struct { iosDnsManager IosDnsManager statusRecorder *peer.Status + stateManager *statemanager.Manager } type handlerWithStop interface { @@ -77,12 +80,7 @@ type muxUpdate struct { } // NewDefaultServer returns a new dns server -func NewDefaultServer( - ctx context.Context, - wgInterface WGIface, - customAddress string, - statusRecorder *peer.Status, -) (*DefaultServer, error) { +func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) { var addrPort *netip.AddrPort if customAddress != "" { parsedAddrPort, err := netip.ParseAddrPort(customAddress) @@ -99,7 +97,7 @@ func NewDefaultServer( dnsService = newServiceViaListener(wgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil + return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil } // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems @@ -112,7 +110,7 @@ func NewDefaultServerPermanentUpstream( statusRecorder *peer.Status, ) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true ds.addHostRootZone() @@ -130,12 +128,12 @@ func NewDefaultServerIos( iosDnsManager IosDnsManager, statusRecorder *peer.Status, ) *DefaultServer { - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) ds.iosDnsManager = iosDnsManager return ds } -func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer { +func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ ctx: ctx, @@ -147,6 +145,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi }, wgInterface: wgInterface, statusRecorder: statusRecorder, + stateManager: stateManager, hostsDNSHolder: newHostsDNSHolder(), } @@ -169,6 +168,7 @@ func (s *DefaultServer) Initialize() (err error) { } } + s.stateManager.RegisterState(&ShutdownState{}) s.hostManager, err = s.initialize() if err != nil { return fmt.Errorf("initialize: %w", err) @@ -191,9 +191,10 @@ func (s *DefaultServer) Stop() { s.ctxCancel() if s.hostManager != nil { - err := s.hostManager.restoreHostDNS() - if err != nil { - log.Error(err) + if err := s.hostManager.restoreHostDNS(); err != nil { + log.Error("failed to restore host DNS settings: ", err) + } else if 
err := s.stateManager.DeleteState(&ShutdownState{}); err != nil { + log.Errorf("failed to delete shutdown dns state: %v", err) } } @@ -318,10 +319,17 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { hostUpdate.RouteAll = false } - if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil { + if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil { log.Error(err) } + // persist dns state right away + ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) + defer cancel() + if err := s.stateManager.PersistState(ctx); err != nil { + log.Errorf("Failed to persist dns state: %v", err) + } + if s.searchDomainNotifier != nil { s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) } @@ -521,7 +529,7 @@ func (s *DefaultServer) upstreamCallbacks( } } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } @@ -551,7 +559,7 @@ func (s *DefaultServer) upstreamCallbacks( s.currentConfig.RouteAll = true s.service.RegisterMux(nbdns.RootZone, handler) } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 4a5aff3eaed..21f1f1b7dde 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" @@ -291,7 +292,7 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) if err != nil { t.Fatal(err) } @@ -400,7 +401,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -495,7 +496,7 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil) if err != nil { t.Fatalf("%v", err) } @@ -554,6 +555,7 @@ func TestDNSServerStartStop(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { hostManager := &mockHostConfigurator{} server := DefaultServer{ + ctx: context.Background(), service: NewServiceViaMemory(&mocWGIface{}), localResolver: &localResolver{ registeredMap: make(registrationMap), @@ -570,7 +572,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { } var domainsUpdate string - hostManager.applyDNSConfigFunc = func(config 
HostDNSConfig) error { + hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { domains := []string{} for _, item := range config.Domains { if item.Disabled { diff --git a/client/internal/dns/server_windows.go b/client/internal/dns/server_windows.go index 5e1494e9ef8..bc051d59bc6 100644 --- a/client/internal/dns/server_windows.go +++ b/client/internal/dns/server_windows.go @@ -1,5 +1,5 @@ package dns -func (s *DefaultServer) initialize() (manager hostManager, err error) { +func (s *DefaultServer) initialize() (hostManager, error) { return newHostManager(s.wgInterface) } diff --git a/client/internal/dns/systemd_freebsd.go b/client/internal/dns/systemd_freebsd.go index 0de805337d9..41c8bf019bb 100644 --- a/client/internal/dns/systemd_freebsd.go +++ b/client/internal/dns/systemd_freebsd.go @@ -7,7 +7,7 @@ import ( var errNotImplemented = errors.New("not implemented") -func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { +func newSystemdDbusConfigurator(string) (restoreHostManager, error) { return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented) } diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index e2fa5b71ae3..a031be5823d 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) @@ -38,6 +39,7 @@ const ( type systemdDbusConfigurator struct { dbusLinkObject dbus.ObjectPath routingAll bool + ifaceName string } // the types below are based on dbus specification, each field is mapped to a dbus type @@ -55,7 +57,7 @@ type systemdDbusLinkDomainsInput struct { MatchOnly bool } -func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { +func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, error) { iface, err := net.InterfaceByName(wgInterface) if err != nil { return nil, fmt.Errorf("get interface: %w", err) @@ -77,6 +79,7 @@ func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { return &systemdDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), + ifaceName: wgInterface, }, nil } @@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool { return true } -func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { parsedIP, err := netip.ParseAddr(config.ServerIP) if err != nil { return fmt.Errorf("unable to parse ip address, error: %w", err) @@ -135,10 +138,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) } - // create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. 
- // The file content itself is not important for systemd restoration - if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: systemdManager, + WgIface: s.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) @@ -174,10 +179,6 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error { return fmt.Errorf("unable to revert link configuration, got error: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return s.flushCaches() } diff --git a/client/internal/dns/unclean_shutdown_android.go b/client/internal/dns/unclean_shutdown_android.go deleted file mode 100644 index 105fb00bf41..00000000000 --- a/client/internal/dns/unclean_shutdown_android.go +++ /dev/null @@ -1,5 +0,0 @@ -package dns - -func CheckUncleanShutdown(string) error { - return nil -} diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index e077ec84d30..9bbdd2b566e 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -3,57 +3,25 @@ package dns import ( - "errors" "fmt" - "io/fs" - "os" - "path/filepath" - - log "github.com/sirupsen/logrus" ) -const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns" - -func CheckUncleanShutdown(string) error { - if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } +type ShutdownState struct { +} - log.Warnf("detected unclean shutdown, file %s exists. 
Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation) +func (s *ShutdownState) Name() string { + return "dns_state" +} +func (s *ShutdownState) Cleanup() error { manager, err := newHostManager() if err != nil { return fmt.Errorf("create host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(nil); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) - } - - return nil -} - -func createUncleanShutdownIndicator() error { - dir := filepath.Dir(fileUncleanShutdownFileLocation) - if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { - return fmt.Errorf("create dir %s: %w", dir, err) - } - - if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec - return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err) + if err := manager.restoreUncleanShutdownDNS(); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } - -func removeUncleanShutdownIndicator() error { - if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err) - } - return nil -} diff --git a/client/internal/dns/unclean_shutdown_ios.go b/client/internal/dns/unclean_shutdown_ios.go deleted file mode 100644 index 105fb00bf41..00000000000 --- a/client/internal/dns/unclean_shutdown_ios.go +++ /dev/null @@ -1,5 +0,0 @@ -package dns - -func CheckUncleanShutdown(string) error { - return nil -} diff --git a/client/internal/dns/unclean_shutdown_mobile.go b/client/internal/dns/unclean_shutdown_mobile.go new file mode 100644 index 00000000000..0d3a2cdbde7 --- /dev/null +++ b/client/internal/dns/unclean_shutdown_mobile.go @@ -0,0 +1,14 @@ +//go:build ios || android + +package dns + +type ShutdownState struct { +} + +func (s *ShutdownState) Name() string { + return "dns_state" +} + +func (s *ShutdownState) Cleanup() error { + return nil +} diff --git a/client/internal/dns/unclean_shutdown_unix.go b/client/internal/dns/unclean_shutdown_unix.go index 8a32090c34d..fcf60c6945c 100644 --- a/client/internal/dns/unclean_shutdown_unix.go +++ b/client/internal/dns/unclean_shutdown_unix.go @@ -3,66 +3,44 @@ package dns import ( - "errors" "fmt" - "io/fs" "net/netip" "os" "path/filepath" - "strings" - log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" ) -func CheckUncleanShutdown(wgIface string) error { - if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } - - log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation) - - managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation) - if err != nil { - return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err) - } - - managerFields := strings.Split(string(managerData), ",") - if len(managerFields) < 2 { - return errors.New("split manager data: insufficient number of fields") - } - osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1] - - dnsAddress, err := netip.ParseAddr(dnsAddressStr) - if err != nil { - return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err) - } - - log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr) +type ShutdownState struct { + ManagerType osManagerType + DNSAddress 
netip.Addr + WgIface string +} - // determine os manager type, so we can invoke the respective restore action - osManagerType, err := newOsManagerType(osManagerTypeStr) - if err != nil { - return fmt.Errorf("detect previous host manager: %w", err) - } +func (s *ShutdownState) Name() string { + return "dns_state" +} - manager, err := newHostManagerFromType(wgIface, osManagerType) +func (s *ShutdownState) Cleanup() error { + manager, err := newHostManagerFromType(s.WgIface, s.ManagerType) if err != nil { return fmt.Errorf("create previous host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) + if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } -func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error { +// TODO: move file contents to state manager +func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error { + dnsAddress, err := netip.ParseAddr(dnsAddressStr) + if err != nil { + return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err) + } + dir := filepath.Dir(fileUncleanShutdownResolvConfLocation) if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { return fmt.Errorf("create dir %s: %w", dir, err) @@ -72,20 +50,13 @@ func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType return fmt.Errorf("create %s: %w", sourcePath, err) } - managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress) - - if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec - return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err) - } - return nil -} - -func removeUncleanShutdownIndicator() error { - if err := os.Remove(fileUncleanShutdownResolvConfLocation); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err) + state := &ShutdownState{ + ManagerType: fileManager, + DNSAddress: dnsAddress, } - if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err) + if err := stateManager.UpdateState(state); err != nil { + return fmt.Errorf("update state: %w", err) } + return nil } diff --git a/client/internal/dns/unclean_shutdown_windows.go b/client/internal/dns/unclean_shutdown_windows.go index 41db46768c5..74e40cc1153 100644 --- a/client/internal/dns/unclean_shutdown_windows.go +++ b/client/internal/dns/unclean_shutdown_windows.go @@ -1,75 +1,26 @@ package dns import ( - "errors" "fmt" - "io/fs" - "os" - "path/filepath" - - "github.com/sirupsen/logrus" -) - -const ( - netbirdProgramDataLocation = "Netbird" - fileUncleanShutdownFile = "unclean_shutdown_dns.txt" ) -func CheckUncleanShutdown(string) error { - file := getUncleanShutdownFile() - - if _, err := os.Stat(file); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } - - logrus.Warnf("detected unclean shutdown, file %s exists. 
Restoring unclean shutdown dns settings.", file) +type ShutdownState struct { + Guid string +} - guid, err := os.ReadFile(file) - if err != nil { - return fmt.Errorf("read %s: %w", file, err) - } +func (s *ShutdownState) Name() string { + return "dns_state" +} - manager, err := newHostManagerWithGuid(string(guid)) +func (s *ShutdownState) Cleanup() error { + manager, err := newHostManagerWithGuid(s.Guid) if err != nil { return fmt.Errorf("create host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(nil); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) + if err := manager.restoreUncleanShutdownDNS(); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } - -func createUncleanShutdownIndicator(guid string) error { - file := getUncleanShutdownFile() - - dir := filepath.Dir(file) - if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { - return fmt.Errorf("create dir %s: %w", dir, err) - } - - if err := os.WriteFile(file, []byte(guid), 0600); err != nil { - return fmt.Errorf("create %s: %w", file, err) - } - - return nil -} - -func removeUncleanShutdownIndicator() error { - file := getUncleanShutdownFile() - - if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", file, err) - } - return nil -} - -func getUncleanShutdownFile() string { - return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile) -} diff --git a/client/internal/engine.go b/client/internal/engine.go index 459518de136..22dd1f584a7 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -23,18 +23,19 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" - - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" + nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -166,6 +167,7 @@ type Engine struct { checks []*mgmProto.Checks relayManager *relayClient.Manager + stateManager *statemanager.Manager } // Peer is an instance of the Connection Peer @@ -213,7 +215,7 @@ func NewEngineWithProbes( probes *ProbeHolder, checks []*mgmProto.Checks, ) *Engine { - return &Engine{ + engine := &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, signal: signalClient, @@ -232,6 +234,11 @@ func NewEngineWithProbes( probes: probes, checks: checks, } + if path := statemanager.GetDefaultStatePath(); path != "" { + engine.stateManager = statemanager.New(path) + } + + return engine } func (e *Engine) Stop() error { @@ -253,7 +260,7 @@ func (e *Engine) Stop() error { e.stopDNSServer() if e.routeManager != nil { - e.routeManager.Stop() + e.routeManager.Stop(e.stateManager) } err := e.removeAllPeers() @@ -275,6 +282,17 
@@ func (e *Engine) Stop() error { e.close() log.Infof("stopped Netbird Engine") + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + if err := e.stateManager.Stop(ctx); err != nil { + return fmt.Errorf("failed to stop state manager: %w", err) + } + if err := e.stateManager.PersistState(ctx); err != nil { + log.Errorf("failed to persist state: %v", err) + } + return nil } @@ -314,6 +332,8 @@ func (e *Engine) Start() error { } } + e.stateManager.Start() + initialRoutes, dnsServer, err := e.newDnsServer() if err != nil { e.close() @@ -322,7 +342,7 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) - beforePeerHook, afterPeerHook, err := e.routeManager.Init() + beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager) if err != nil { log.Errorf("Failed to initialize route manager: %s", err) } else { @@ -1219,10 +1239,11 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder) return nil, dnsServer, nil default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder) + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager) if err != nil { return nil, nil, err } + return nil, dnsServer, nil } } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index d7ddf7ae8b7..0a1c7dc56b8 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" + "github.com/netbirdio/netbird/client/internal/statemanager" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -31,14 +32,14 @@ import ( // Manager is a route manager interface type Manager interface { - Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) + Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string EnableServerRouter(firewall firewall.Manager) error - Stop() + Stop(stateManager *statemanager.Manager) } // DefaultManager is the default instance of a route manager @@ -120,12 +121,12 @@ func NewManager( } // Init sets up the routing -func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { if nbnet.CustomRoutingDisabled() { return nil, nil, nil } - if err := m.sysOps.CleanupRouting(); err != nil { + if err := m.sysOps.CleanupRouting(nil); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -136,7 +137,7 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) ips := resolveURLsToIPs(initialAddresses) - 
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips) + beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager) if err != nil { return nil, nil, fmt.Errorf("setup routing: %w", err) } @@ -154,7 +155,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { } // Stop stops the manager watchers and clean firewall rules -func (m *DefaultManager) Stop() { +func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() if m.serverRouter != nil { m.serverRouter.cleanUp() @@ -172,7 +173,7 @@ func (m *DefaultManager) Stop() { } if !nbnet.CustomRoutingDisabled() { - if err := m.sysOps.CleanupRouting(); err != nil { + if err := m.sysOps.CleanupRouting(stateManager); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { log.Info("Routing cleanup complete") diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 044a996c777..e669bc44a08 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -426,10 +426,10 @@ func TestManagerUpdateRoutes(t *testing.T) { ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) - _, _, err = routeManager.Init() + _, _, err = routeManager.Init(nil) require.NoError(t, err, "should init route manager") - defer routeManager.Stop() + defer routeManager.Stop(nil) if testCase.removeSrvRouter { routeManager.serverRouter = nil diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 908279c885a..503185f0311 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/routeselector" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/util/net" ) @@ -17,10 +18,10 @@ type MockManager struct { UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector - StopFunc func() + StopFunc func(manager *statemanager.Manager) } -func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { +func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) { return nil, nil, nil } @@ -65,8 +66,8 @@ func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error { } // Stop mock implementation of Stop from Manager interface -func (m *MockManager) Stop() { +func (m *MockManager) Stop(stateManager *statemanager.Manager) { if m.StopFunc != nil { - m.StopFunc() + m.StopFunc(stateManager) } } diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go new file mode 100644 index 00000000000..26992467750 --- /dev/null +++ b/client/internal/routemanager/systemops/state.go @@ -0,0 +1,81 @@ +package systemops + +import ( + "encoding/json" + "fmt" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + + nberrors "github.com/netbirdio/netbird/client/errors" +) + +type RouteEntry struct { + Prefix netip.Prefix `json:"prefix"` + Nexthop Nexthop `json:"nexthop"` +} + +type ShutdownState struct { + Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"` + mu sync.RWMutex +} 
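Editorial sketch (not part of the diff): this ShutdownState is meant to be registered with the statemanager package introduced later in this series, updated as routes are added or removed, and its Cleanup invoked on the next start if the previous run died without removing its routes. The sketch below assumes exactly the APIs added in this series; the example prefix, the zero-value Nexthop, and the recordRoute/restoreAfterCrash wrappers are illustrative, not patch code.

// Editorial sketch only — not part of the patch. Assumes the statemanager and
// systemops APIs exactly as introduced in this series; names marked below are made up.
package main

import (
	"context"
	"net/netip"

	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
	"github.com/netbirdio/netbird/client/internal/statemanager"
)

// recordRoute persists a route to the shutdown state before it is written to the
// routing table. (The patch re-reads any previously stored state via GetState
// first so earlier routes are kept; that step is omitted here for brevity.)
func recordRoute(mgr *statemanager.Manager, prefix netip.Prefix, nexthop systemops.Nexthop) {
	state := systemops.NewShutdownState()
	state.UpdateRoute(prefix, nexthop)
	if err := mgr.UpdateState(state); err != nil {
		log.Errorf("update state: %v", err)
	}
	if err := mgr.PersistState(context.Background()); err != nil {
		log.Errorf("persist state: %v", err)
	}
}

// restoreAfterCrash is what the next start does: Cleanup() runs for every state
// still present in state.json, removing routes a crashed client left behind.
func restoreAfterCrash() {
	mgr := statemanager.New(statemanager.GetDefaultStatePath())
	mgr.RegisterState(&systemops.ShutdownState{})
	if err := mgr.PerformCleanup(); err != nil {
		log.Errorf("cleanup: %v", err)
	}
}

func main() {
	mgr := statemanager.New(statemanager.GetDefaultStatePath())
	mgr.RegisterState(&systemops.ShutdownState{})
	recordRoute(mgr, netip.MustParsePrefix("100.64.0.0/10"), systemops.Nexthop{}) // made-up prefix
	restoreAfterCrash() // in reality this runs in the next process, not the same one
}

In the patch itself the recording happens inside setupRefCounter's add/remove hooks (systemops_generic.go), and the restore path is driven from client/server/server.go via restoreResidualState.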
+ +func NewShutdownState() *ShutdownState { + return &ShutdownState{ + Routes: make(map[netip.Prefix]RouteEntry), + } +} + +func (s *ShutdownState) Name() string { + return "route_state" +} + +func (s *ShutdownState) Cleanup() error { + sysops := NewSysOps(nil, nil) + var merr *multierror.Error + + s.mu.RLock() + defer s.mu.RUnlock() + + for _, route := range s.Routes { + if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err)) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) { + s.mu.Lock() + defer s.mu.Unlock() + + s.Routes[prefix] = RouteEntry{ + Prefix: prefix, + Nexthop: nexthop, + } +} + +func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.Routes, prefix) +} + +// MarshalJSON ensures that empty routes are marshaled as null +func (s *ShutdownState) MarshalJSON() ([]byte, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if len(s.Routes) == 0 { + return json.Marshal(nil) + } + + return json.Marshal(s.Routes) +} + +func (s *ShutdownState) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &s.Routes) +} diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index 5e97a4a5f53..ca8aea3fbce 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -9,14 +9,15 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return nil, nil, nil } -func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(*statemanager.Manager) error { return nil } @@ -28,6 +29,10 @@ func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error { return nil } +func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error { + return nil +} + func EnableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 9258f4a4e3b..2b8a14ea2d2 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -30,7 +31,9 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var ErrRoutingIsSeparate = errors.New("routing is separate") -func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + stateManager.RegisterState(&ShutdownState{}) + initialNextHopV4, err := 
GetNextHop(netip.IPv4Unspecified()) if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { log.Errorf("Unable to get initial v4 default next hop: %v", err) @@ -53,9 +56,18 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn // These errors are not critical, but also we should not track and try to remove the routes either. return nexthop, refcounter.ErrIgnore } + + r.updateState(stateManager, prefix, nexthop) + return nexthop, err }, - r.removeFromRouteTable, + func(prefix netip.Prefix, nexthop Nexthop) error { + // remove from state even if we have trouble removing it from the route table + // it could be already gone + r.removeFromState(stateManager, prefix) + + return r.removeFromRouteTable(prefix, nexthop) + }, ) r.refCounter = refCounter @@ -63,7 +75,25 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn return r.setupHooks(initAddresses) } -func (r *SysOps) cleanupRefCounter() error { +func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) { + state := getState(stateManager) + state.UpdateRoute(prefix, nexthop) + + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update state: %v", err) + } +} + +func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) { + state := getState(stateManager) + state.RemoveRoute(prefix) + + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("Failed to update state: %v", err) + } +} + +func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { if r.refCounter == nil { return nil } @@ -76,6 +106,10 @@ func (r *SysOps) cleanupRefCounter() error { return fmt.Errorf("flush route manager: %w", err) } + if err := stateManager.DeleteState(&ShutdownState{}); err != nil { + log.Errorf("failed to delete state: %v", err) + } + return nil } @@ -506,3 +540,14 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P // Return true if the longest matching prefix is from vpnRoutes return isVpn, longestPrefix } + +func getState(stateManager *statemanager.Manager) *ShutdownState { + var shutdownState *ShutdownState + if state := stateManager.GetState(shutdownState); state != nil { + shutdownState = state.(*ShutdownState) + } else { + shutdownState = NewShutdownState() + } + + return shutdownState +} diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index ce5b6b8431b..5b7b13f97f8 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -77,10 +77,10 @@ func TestAddRemoveRoutes(t *testing.T) { r := NewSysOps(wgInterface, nil) - _, _, err = r.SetupRouting(nil) + _, _, err = r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting()) + assert.NoError(t, r.CleanupRouting(nil)) }) index, err := net.InterfaceByName(wgInterface.Name()) @@ -403,10 +403,10 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - _, _, err := r.SetupRouting(nil) + _, _, err := r.SetupRouting(nil, nil) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting()) + assert.NoError(t, r.CleanupRouting(nil)) }) index, err := net.InterfaceByName(wgInterface.Name()) diff --git a/client/internal/routemanager/systemops/systemops_ios.go 
b/client/internal/routemanager/systemops/systemops_ios.go index 7cfb2b29895..bf06f373998 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -9,17 +9,18 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) return nil, nil, nil } -func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(*statemanager.Manager) error { r.mu.Lock() defer r.mu.Unlock() @@ -46,6 +47,18 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error { return nil } +func (r *SysOps) notify() { + prefixes := make([]netip.Prefix, 0, len(r.prefixes)) + for prefix := range r.prefixes { + prefixes = append(prefixes, prefix) + } + r.notifier.OnNewPrefixes(prefixes) +} + +func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error { + return nil +} + func EnableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil @@ -54,11 +67,3 @@ func EnableIPForwarding() error { func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) { return false, netip.Prefix{} } - -func (r *SysOps) notify() { - prefixes := make([]netip.Prefix, 0, len(r.prefixes)) - for prefix := range r.prefixes { - prefixes = append(prefixes, prefix) - } - r.notifier.OnNewPrefixes(prefixes) -} diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 2d0c5782697..0124fd95e85 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -18,6 +18,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -85,10 +86,10 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. 
-func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { if isLegacy() { log.Infof("Using legacy routing setup") - return r.setupRefCounter(initAddresses) + return r.setupRefCounter(initAddresses, stateManager) } if err = addRoutingTableName(); err != nil { @@ -104,7 +105,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb defer func() { if err != nil { - if cleanErr := r.CleanupRouting(); cleanErr != nil { + if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { log.Errorf("Error cleaning up routing: %v", cleanErr) } } @@ -116,7 +117,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb if errors.Is(err, syscall.EOPNOTSUPP) { log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") setIsLegacy(true) - return r.setupRefCounter(initAddresses) + return r.setupRefCounter(initAddresses, stateManager) } return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } @@ -128,9 +129,9 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. -func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { if isLegacy() { - return r.cleanupRefCounter() + return r.cleanupRefCounter(stateManager) } var result *multierror.Error diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index a2bbf35cf09..0f8f2a34175 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -13,15 +13,16 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { - return r.setupRefCounter(initAddresses) +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting() error { - return r.cleanupRefCounter() +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { + return r.cleanupRefCounter(stateManager) } func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 3f756788e70..b1732a08001 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -130,12 +131,12 @@ const ( RouteDeleted ) -func (r 
*SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { - return r.setupRefCounter(initAddresses) +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting() error { - return r.cleanupRefCounter() +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { + return r.cleanupRefCounter(stateManager) } func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go new file mode 100644 index 00000000000..a5a14f807a2 --- /dev/null +++ b/client/internal/statemanager/manager.go @@ -0,0 +1,298 @@ +package statemanager + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "reflect" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nberrors "github.com/netbirdio/netbird/client/errors" +) + +// State interface defines the methods that all state types must implement +type State interface { + Name() string + Cleanup() error +} + +// Manager handles the persistence and management of various states +type Manager struct { + mu sync.Mutex + cancel context.CancelFunc + done chan struct{} + + filePath string + // holds the states that are registered with the manager and that are to be persisted + states map[string]State + // holds the state names that have been updated and need to be persisted with the next save + dirty map[string]struct{} + // holds the type information for each registered state + stateTypes map[string]reflect.Type +} + +// New creates a new Manager instance +func New(filePath string) *Manager { + return &Manager{ + filePath: filePath, + states: make(map[string]State), + dirty: make(map[string]struct{}), + stateTypes: make(map[string]reflect.Type), + } +} + +// Start starts the state manager periodic save routine +func (m *Manager) Start() { + if m == nil { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + var ctx context.Context + ctx, m.cancel = context.WithCancel(context.Background()) + m.done = make(chan struct{}) + + go m.periodicStateSave(ctx) +} + +func (m *Manager) Stop(ctx context.Context) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel != nil { + m.cancel() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: + return nil + } + } + + return nil +} + +// RegisterState registers a state with the manager but doesn't attempt to persist it. +// Pass an uninitialized state to register it. +func (m *Manager) RegisterState(state State) { + if m == nil { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + name := state.Name() + m.states[name] = nil + m.stateTypes[name] = reflect.TypeOf(state).Elem() +} + +// GetState returns the state for the given type +func (m *Manager) GetState(state State) State { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + return m.states[state.Name()] +} + +// UpdateState updates the state in the manager and marks it as dirty for the next save. +// The state will be replaced with the new one. +func (m *Manager) UpdateState(state State) error { + if m == nil { + return nil + } + + return m.setState(state.Name(), state) +} + +// DeleteState removes the state from the manager and marks it as dirty for the next save. 
+// Pass an uninitialized state to delete it. +func (m *Manager) DeleteState(state State) error { + if m == nil { + return nil + } + + return m.setState(state.Name(), nil) +} + +func (m *Manager) setState(name string, state State) error { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.states[name]; !exists { + return fmt.Errorf("state %s not registered", name) + } + + m.states[name] = state + m.dirty[name] = struct{}{} + + return nil +} + +func (m *Manager) periodicStateSave(ctx context.Context) { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + defer close(m.done) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := m.PersistState(ctx); err != nil { + log.Errorf("failed to persist state: %v", err) + } + } + } +} + +// PersistState persists the states that have been updated since the last save. +func (m *Manager) PersistState(ctx context.Context) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.dirty) == 0 { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + done := make(chan error, 1) + + go func() { + data, err := json.MarshalIndent(m.states, "", " ") + if err != nil { + done <- fmt.Errorf("marshal states: %w", err) + return + } + + // nolint:gosec + if err := os.WriteFile(m.filePath, data, 0640); err != nil { + done <- fmt.Errorf("write state file: %w", err) + return + } + + done <- nil + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + if err != nil { + return err + } + } + + log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) + + clear(m.dirty) + + return nil +} + +// loadState loads the existing state from the state file +func (m *Manager) loadState() error { + data, err := os.ReadFile(m.filePath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + log.Debug("state file does not exist") + return nil + } + return fmt.Errorf("read state file: %w", err) + } + + var rawStates map[string]json.RawMessage + if err := json.Unmarshal(data, &rawStates); err != nil { + log.Warn("State file appears to be corrupted, attempting to delete it") + if err := os.Remove(m.filePath); err != nil { + log.Errorf("Failed to delete corrupted state file: %v", err) + } else { + log.Info("State file deleted") + } + return fmt.Errorf("unmarshal states: %w", err) + } + + var merr *multierror.Error + + for name, rawState := range rawStates { + stateType, ok := m.stateTypes[name] + if !ok { + merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name)) + continue + } + + if string(rawState) == "null" { + continue + } + + statePtr := reflect.New(stateType).Interface().(State) + if err := json.Unmarshal(rawState, statePtr); err != nil { + merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err)) + continue + } + + m.states[name] = statePtr + log.Debugf("loaded state: %s", name) + } + + return nberrors.FormatErrorOrNil(merr) +} + +// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them. +// If the cleanup is successful, the state is marked for deletion. 
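Editorial note (not part of the diff): any subsystem can opt into this recovery mechanism by implementing the two-method State interface defined at the top of this file; dns.ShutdownState and systemops.ShutdownState are the two real implementations in this series. A hypothetical example with made-up names (FirewallShutdownState, removeRule) to show the shape:

package firewallstate // hypothetical package, not part of the patch

import "fmt"

// FirewallShutdownState is a made-up example of a statemanager State: it only
// needs a stable Name and a Cleanup that undoes whatever was recorded.
type FirewallShutdownState struct {
	Rules []string `json:"rules,omitempty"`
}

// Name is the key under which this state is stored in state.json.
func (s *FirewallShutdownState) Name() string { return "firewall_state" }

// Cleanup is called by PerformCleanup when the previous run did not shut down cleanly.
func (s *FirewallShutdownState) Cleanup() error {
	for _, rule := range s.Rules {
		if err := removeRule(rule); err != nil {
			return fmt.Errorf("remove rule %s: %w", rule, err)
		}
	}
	return nil
}

// removeRule is a stand-in for real firewall plumbing.
func removeRule(string) error { return nil }

The server side of the patch (client/server/server.go) follows the same pattern with the two real state types: it registers them, calls PerformCleanup, and persists the result.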
+func (m *Manager) PerformCleanup() error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if err := m.loadState(); err != nil { + log.Warnf("Failed to load state during cleanup: %v", err) + } + + var merr *multierror.Error + for name, state := range m.states { + if state == nil { + // If no state was found in the state file, we don't mark the state dirty nor return an error + continue + } + + log.Infof("client was not shut down properly, cleaning up %s", name) + if err := state.Cleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err)) + } else { + // mark for deletion on cleanup success + m.states[name] = nil + m.dirty[name] = struct{}{} + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go new file mode 100644 index 00000000000..64c5316d871 --- /dev/null +++ b/client/internal/statemanager/path.go @@ -0,0 +1,35 @@ +package statemanager + +import ( + "os" + "path/filepath" + "runtime" + + "github.com/sirupsen/logrus" +) + +// GetDefaultStatePath returns the path to the state file based on the operating system +// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. +func GetDefaultStatePath() string { + var path string + + switch runtime.GOOS { + case "windows": + path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") + case "darwin", "linux": + path = "/var/lib/netbird/state.json" + case "freebsd", "openbsd", "netbsd", "dragonfly": + path = "/var/db/netbird/state.json" + // ios/android don't need state + default: + return "" + } + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + logrus.Errorf("Error creating directory %s: %v. 
Continuing without state support.", dir, err) + return "" + } + + return path +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index dc13706bf3f..9d65bdbe080 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -138,12 +138,12 @@ func (c *Client) Stop() { c.ctxCancel() } -// ÏSetTraceLogLevel configure the logger to trace level +// SetTraceLogLevel configure the logger to trace level func (c *Client) SetTraceLogLevel() { log.SetLevel(log.TraceLevel) } -// getStatusDetails return with the list of the PeerInfos +// GetStatusDetails return with the list of the PeerInfos func (c *Client) GetStatusDetails() *StatusDetails { fullStatus := c.recorder.GetFullStatus() diff --git a/client/server/server.go b/client/server/server.go index 0a4c1813159..342f61b883f 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -11,6 +11,7 @@ import ( "time" "github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/durationpb" @@ -20,7 +21,11 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/auth" + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/internal" @@ -39,6 +44,8 @@ const ( defaultMaxRetryInterval = 60 * time.Minute defaultMaxRetryTime = 14 * 24 * time.Hour defaultRetryMultiplier = 1.7 + + errRestoreResidualState = "failed to restore residual state: %v" ) // Server for service control. @@ -95,6 +102,10 @@ func (s *Server) Start() error { defer s.mutex.Unlock() state := internal.CtxGetState(s.rootCtx) + if err := restoreResidualState(s.rootCtx); err != nil { + log.Warnf(errRestoreResidualState, err) + } + // if current state contains any error, return it // in all other cases we can continue execution only if status is idle and up command was // not in the progress or already successfully established connection. @@ -292,6 +303,10 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.actCancel = cancel s.mutex.Unlock() + if err := restoreResidualState(ctx); err != nil { + log.Warnf(errRestoreResidualState, err) + } + state := internal.CtxGetState(ctx) defer func() { status, err := state.Status() @@ -549,6 +564,10 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes s.mutex.Lock() defer s.mutex.Unlock() + if err := restoreResidualState(callerCtx); err != nil { + log.Warnf(errRestoreResidualState, err) + } + state := internal.CtxGetState(s.rootCtx) // if current state contains any error, return it @@ -829,3 +848,31 @@ func sendTerminalNotification() error { return wallCmd.Wait() } + +// restoreResidualState checks if the client was not shut down in a clean way and restores residual state if required. +// Otherwise, we might not be able to connect to the management server to retrieve new config. 
+func restoreResidualState(ctx context.Context) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + mgr := statemanager.New(path) + + var merr *multierror.Error + + // register the states we are interested in restoring + // this will also allow each subsystem to record its own state + mgr.RegisterState(&dns.ShutdownState{}) + mgr.RegisterState(&systemops.ShutdownState{}) + + if err := mgr.PerformCleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) + } + + if err := mgr.PersistState(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} From 4e918e55ba0fe9cd87a3b3eccc658d1a7deeda0f Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 24 Oct 2024 11:43:14 +0200 Subject: [PATCH 59/81] [client] Fix controller re-connection (#2758) Rethink the peer reconnection implementation --- .github/workflows/golang-test-linux.yml | 7 +- client/iface/bind/ice_bind.go | 1 - client/iface/wgproxy/ebpf/portlookup.go | 4 +- client/iface/wgproxy/ebpf/portlookup_test.go | 3 + client/iface/wgproxy/factory_kernel.go | 2 + .../iface/wgproxy/factory_kernel_freebsd.go | 3 + client/iface/wgproxy/factory_usp.go | 3 + client/iface/wgproxy/udp/proxy.go | 28 ++- client/internal/engine.go | 23 +- client/internal/engine_test.go | 3 + client/internal/peer/conn.go | 226 +++++------------- client/internal/peer/conn_monitor.go | 212 ---------------- client/internal/peer/conn_test.go | 16 +- client/internal/peer/guard/guard.go | 194 +++++++++++++++ client/internal/peer/guard/ice_monitor.go | 135 +++++++++++ client/internal/peer/guard/sr_watcher.go | 119 +++++++++ client/internal/peer/ice/agent.go | 89 +++++++ client/internal/peer/ice/config.go | 22 ++ .../peer/{env_config.go => ice/env.go} | 22 +- client/internal/peer/{ => ice}/stdnet.go | 2 +- .../internal/peer/{ => ice}/stdnet_android.go | 2 +- client/internal/peer/worker_ice.go | 107 +-------- client/internal/peer/worker_relay.go | 11 +- relay/client/client.go | 26 +- relay/client/guard.go | 20 ++ relay/client/manager.go | 21 ++ signal/client/client.go | 1 + signal/client/grpc.go | 14 ++ signal/client/mock.go | 22 +- 29 files changed, 814 insertions(+), 524 deletions(-) delete mode 100644 client/internal/peer/conn_monitor.go create mode 100644 client/internal/peer/guard/guard.go create mode 100644 client/internal/peer/guard/ice_monitor.go create mode 100644 client/internal/peer/guard/sr_watcher.go create mode 100644 client/internal/peer/ice/agent.go create mode 100644 client/internal/peer/ice/config.go rename client/internal/peer/{env_config.go => ice/env.go} (80%) rename client/internal/peer/{ => ice}/stdnet.go (94%) rename client/internal/peer/{ => ice}/stdnet_android.go (94%) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 9457d3a6621..b584f0ff68c 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -79,9 +79,6 @@ jobs: - name: check git status run: git --no-pager diff --exit-code - - name: Generate Iface Test bin - run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/ - - name: Generate Shared Sock Test bin run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock @@ -98,7 +95,7 @@ jobs: run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal - name: Generate Peer Test bin - run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/... 
+ run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/ - run: chmod +x *testing.bin @@ -106,7 +103,7 @@ jobs: run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1 - name: Run Iface tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1 + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/... - name: Run RouteManager tests in docker run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index ccdcc2cda30..a9c25950d00 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -143,7 +143,6 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { conn, ok := b.endpoints[ep.DstIP()] b.endpointsMu.Unlock() if !ok { - log.Infof("failed to find endpoint for %s", ep.DstIP()) return b.StdNetBind.Send(bufs, ep) } diff --git a/client/iface/wgproxy/ebpf/portlookup.go b/client/iface/wgproxy/ebpf/portlookup.go index 0e2c20c9911..fce8f1507c6 100644 --- a/client/iface/wgproxy/ebpf/portlookup.go +++ b/client/iface/wgproxy/ebpf/portlookup.go @@ -5,9 +5,9 @@ import ( "net" ) -const ( +var ( portRangeStart = 3128 - portRangeEnd = 3228 + portRangeEnd = portRangeStart + 100 ) type portLookup struct { diff --git a/client/iface/wgproxy/ebpf/portlookup_test.go b/client/iface/wgproxy/ebpf/portlookup_test.go index 92f4b8eee9f..a2e92fc7926 100644 --- a/client/iface/wgproxy/ebpf/portlookup_test.go +++ b/client/iface/wgproxy/ebpf/portlookup_test.go @@ -17,6 +17,9 @@ func Test_portLookup_searchFreePort(t *testing.T) { func Test_portLookup_on_allocated(t *testing.T) { pl := portLookup{} + portRangeStart = 4128 + portRangeEnd = portRangeStart + 100 + allocatedPort, err := allocatePort(portRangeStart) if err != nil { t.Fatal(err) diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index 32e96e34f2d..3ad7dc59dd9 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -22,9 +22,11 @@ func NewKernelFactory(wgPort int) *KernelFactory { ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) if err := ebpfProxy.Listen(); err != nil { + log.Infof("WireGuard Proxy Factory will produce UDP proxy") log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) return f } + log.Infof("WireGuard Proxy Factory will produce eBPF proxy") f.ebpfProxy = ebpfProxy return f } diff --git a/client/iface/wgproxy/factory_kernel_freebsd.go b/client/iface/wgproxy/factory_kernel_freebsd.go index 7ac2f99a882..736944229fc 100644 --- a/client/iface/wgproxy/factory_kernel_freebsd.go +++ b/client/iface/wgproxy/factory_kernel_freebsd.go @@ -1,6 +1,8 @@ package wgproxy import ( + log "github.com/sirupsen/logrus" + udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" ) @@ -10,6 +12,7 @@ type KernelFactory struct { } func 
NewKernelFactory(wgPort int) *KernelFactory { + log.Infof("WireGuard Proxy Factory will produce UDP proxy") f := &KernelFactory{ wgPort: wgPort, } diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go index 99f5ada017a..e2d479331b7 100644 --- a/client/iface/wgproxy/factory_usp.go +++ b/client/iface/wgproxy/factory_usp.go @@ -1,6 +1,8 @@ package wgproxy import ( + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface/bind" proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" ) @@ -10,6 +12,7 @@ type USPFactory struct { } func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { + log.Infof("WireGuard Proxy Factory will produce bind proxy") f := &USPFactory{ bind: iceBind, } diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index 8bee099014e..200d961f3c8 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -2,14 +2,16 @@ package udp import ( "context" + "errors" "fmt" + "io" "net" "sync" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/errors" + cerrors "github.com/netbirdio/netbird/client/errors" ) // WGUDPProxy proxies @@ -121,7 +123,7 @@ func (p *WGUDPProxy) close() error { if err := p.localConn.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) } - return errors.FormatErrorOrNil(result) + return cerrors.FormatErrorOrNil(result) } // proxyToRemote proxies from Wireguard to the RemoteKey @@ -160,18 +162,16 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) { func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { defer func() { if err := p.close(); err != nil { - log.Warnf("error in proxy to local loop: %s", err) + if !errors.Is(err, io.EOF) { + log.Warnf("error in proxy to local loop: %s", err) + } } }() buf := make([]byte, 1500) for { - n, err := p.remoteConn.Read(buf) + n, err := p.remoteConnRead(ctx, buf) if err != nil { - if ctx.Err() != nil { - return - } - log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) return } @@ -193,3 +193,15 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { } } } + +func (p *WGUDPProxy) remoteConnRead(ctx context.Context, buf []byte) (n int, err error) { + n, err = p.remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.LocalAddr(), err) + return + } + return +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 22dd1f584a7..af2817e6ed3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -30,6 +30,8 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peer/guard" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -168,6 +170,7 @@ type Engine struct { relayManager *relayClient.Manager stateManager *statemanager.Manager + srWatcher *guard.SRWatcher } // Peer is an instance of the Connection Peer @@ -263,6 +266,10 @@ func (e *Engine) Stop() error { e.routeManager.Stop(e.stateManager) } + if e.srWatcher != nil { + e.srWatcher.Close() + } + err := e.removeAllPeers() if err != nil { return 
fmt.Errorf("failed to remove all peers: %s", err) @@ -389,6 +396,18 @@ func (e *Engine) Start() error { return fmt.Errorf("initialize dns server: %w", err) } + iceCfg := icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + UDPMux: e.udpMux.UDPMuxDefault, + UDPMuxSrflx: e.udpMux, + NATExternalIPs: e.parseNATExternalIPMappings(), + } + + e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) + e.srWatcher.Start() + e.receiveSignalEvents() e.receiveManagementEvents() e.receiveProbeEvents() @@ -971,7 +990,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e LocalWgPort: e.config.WgPort, RosenpassPubKey: e.getRosenpassPubKey(), RosenpassAddr: e.getRosenpassAddr(), - ICEConfig: peer.ICEConfig{ + ICEConfig: icemaker.Config{ StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, @@ -981,7 +1000,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e }, } - peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) + peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher) if err != nil { return nil, err } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index d0ba1fffcf1..0018af6df8f 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -29,6 +29,8 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peer/guard" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -258,6 +260,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { } engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.ctx = ctx + engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) type testCase struct { name string diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 99acfde314e..56b772759a2 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -10,7 +10,6 @@ import ( "sync" "time" - "github.com/cenkalti/backoff/v4" "github.com/pion/ice/v3" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -18,6 +17,8 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/wgproxy" + "github.com/netbirdio/netbird/client/internal/peer/guard" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" @@ -32,8 +33,6 @@ const ( connPriorityRelay ConnPriority = 1 connPriorityICETurn ConnPriority = 1 connPriorityICEP2P ConnPriority = 2 - - reconnectMaxElapsedTime = 30 * time.Minute ) type WgConfig struct { @@ -63,7 +62,7 @@ type ConnConfig struct { RosenpassAddr string // ICEConfig ICE protocol configuration - ICEConfig ICEConfig + ICEConfig icemaker.Config } type WorkerCallbacks struct { @@ -106,16 
+105,12 @@ type Conn struct { wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy - // for reconnection operations - iCEDisconnected chan bool - relayDisconnected chan bool - connMonitor *ConnMonitor - reconnectCh <-chan struct{} + guard *guard.Guard } // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open -func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) { +func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) @@ -126,28 +121,18 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu connLog := log.WithField("peer", config.Key) var conn = &Conn{ - log: connLog, - ctx: ctx, - ctxCancel: ctxCancel, - config: config, - statusRecorder: statusRecorder, - signaler: signaler, - relayManager: relayManager, - allowedIP: allowedIP, - allowedNet: allowedNet.String(), - statusRelay: NewAtomicConnStatus(), - statusICE: NewAtomicConnStatus(), - iCEDisconnected: make(chan bool, 1), - relayDisconnected: make(chan bool, 1), - } - - conn.connMonitor, conn.reconnectCh = NewConnMonitor( - signaler, - iFaceDiscover, - config, - conn.relayDisconnected, - conn.iCEDisconnected, - ) + log: connLog, + ctx: ctx, + ctxCancel: ctxCancel, + config: config, + statusRecorder: statusRecorder, + signaler: signaler, + relayManager: relayManager, + allowedIP: allowedIP, + allowedNet: allowedNet.String(), + statusRelay: NewAtomicConnStatus(), + statusICE: NewAtomicConnStatus(), + } rFns := WorkerRelayCallbacks{ OnConnReady: conn.relayConnectionIsReady, @@ -159,7 +144,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu OnStatusChanged: conn.onWorkerICEStateDisconnected, } - conn.workerRelay = NewWorkerRelay(connLog, config, relayManager, rFns) + ctrl := isController(config) + conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, relayManager, rFns) relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns) @@ -174,6 +160,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) } + conn.guard = guard.NewGuard(connLog, ctrl, conn.isConnectedOnAllWay, config.Timeout, srWatcher) + go conn.handshaker.Listen() return conn, nil @@ -184,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu // be used. 
func (conn *Conn) Open() { conn.log.Debugf("open connection to peer") + conn.mu.Lock() defer conn.mu.Unlock() conn.opened = true @@ -200,24 +189,19 @@ func (conn *Conn) Open() { conn.log.Warnf("error while updating the state err: %v", err) } - go conn.startHandshakeAndReconnect() + go conn.startHandshakeAndReconnect(conn.ctx) } -func (conn *Conn) startHandshakeAndReconnect() { - conn.waitInitialRandomSleepTime() +func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { + conn.waitInitialRandomSleepTime(ctx) err := conn.handshaker.sendOffer() if err != nil { conn.log.Errorf("failed to send initial offer: %v", err) } - go conn.connMonitor.Start(conn.ctx) - - if conn.workerRelay.IsController() { - conn.reconnectLoopWithRetry() - } else { - conn.reconnectLoopForOnDisconnectedEvent() - } + go conn.guard.Start(ctx) + go conn.listenGuardEvent(ctx) } // Close closes this peer Conn issuing a close event to the Conn closeCh @@ -316,104 +300,6 @@ func (conn *Conn) GetKey() string { return conn.config.Key } -func (conn *Conn) reconnectLoopWithRetry() { - // Give chance to the peer to establish the initial connection. - // With it, we can decrease to send necessary offer - select { - case <-conn.ctx.Done(): - return - case <-time.After(3 * time.Second): - } - - ticker := conn.prepareExponentTicker() - defer ticker.Stop() - time.Sleep(1 * time.Second) - - for { - select { - case t := <-ticker.C: - if t.IsZero() { - // in case if the ticker has been canceled by context then avoid the temporary loop - return - } - - if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { - if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected { - conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) - } - } else { - if conn.statusICE.Get() == StatusDisconnected { - conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE) - } - } - - // checks if there is peer connection is established via relay or ice - if conn.isConnected() { - continue - } - - err := conn.handshaker.sendOffer() - if err != nil { - conn.log.Errorf("failed to do handshake: %v", err) - } - - case <-conn.reconnectCh: - ticker.Stop() - ticker = conn.prepareExponentTicker() - - case <-conn.ctx.Done(): - conn.log.Debugf("context is done, stop reconnect loop") - return - } - } -} - -func (conn *Conn) prepareExponentTicker() *backoff.Ticker { - bo := backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 0.1, - Multiplier: 2, - MaxInterval: conn.config.Timeout, - MaxElapsedTime: reconnectMaxElapsedTime, - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, conn.ctx) - - ticker := backoff.NewTicker(bo) - <-ticker.C // consume the initial tick what is happening right after the ticker has been created - - return ticker -} - -// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and it should reconnect to the peer -// when the connection is lost. It will try to establish a connection only once time if before the connection was established -// It track separately the ice and relay connection status. Just because a lover priority connection reestablished it does not -// mean that to switch to it. We always force to use the higher priority connection. 
-func (conn *Conn) reconnectLoopForOnDisconnectedEvent() { - for { - select { - case changed := <-conn.relayDisconnected: - if !changed { - continue - } - conn.log.Debugf("Relay state changed, try to send new offer") - case changed := <-conn.iCEDisconnected: - if !changed { - continue - } - conn.log.Debugf("ICE state changed, try to send new offer") - case <-conn.ctx.Done(): - conn.log.Debugf("context is done, stop reconnect loop") - return - } - - err := conn.handshaker.SendOffer() - if err != nil { - conn.log.Errorf("failed to do handshake: %v", err) - } - } -} - // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { conn.mu.Lock() @@ -513,7 +399,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { changed := conn.statusICE.Get() != newState && newState != StatusConnecting conn.statusICE.Set(newState) - conn.notifyReconnectLoopICEDisconnected(changed) + conn.guard.SetICEConnDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -604,7 +490,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { changed := conn.statusRelay.Get() != StatusDisconnected conn.statusRelay.Set(StatusDisconnected) - conn.notifyReconnectLoopRelayDisconnected(changed) + conn.guard.SetRelayedConnDisconnected(changed) peerState := State{ PubKey: conn.config.Key, @@ -617,6 +503,20 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { } } +func (conn *Conn) listenGuardEvent(ctx context.Context) { + for { + select { + case <-conn.guard.Reconnect: + conn.log.Debugf("send offer to peer") + if err := conn.handshaker.SendOffer(); err != nil { + conn.log.Errorf("failed to send offer: %v", err) + } + case <-ctx.Done(): + return + } + } +} + func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error { return conn.config.WgConfig.WgInterface.UpdatePeer( conn.config.WgConfig.RemoteKey, @@ -693,7 +593,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } } -func (conn *Conn) waitInitialRandomSleepTime() { +func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) { minWait := 100 maxWait := 800 duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond @@ -702,7 +602,7 @@ func (conn *Conn) waitInitialRandomSleepTime() { defer timeout.Stop() select { - case <-conn.ctx.Done(): + case <-ctx.Done(): case <-timeout.C: } } @@ -731,11 +631,17 @@ func (conn *Conn) evalStatus() ConnStatus { return StatusDisconnected } -func (conn *Conn) isConnected() bool { +func (conn *Conn) isConnectedOnAllWay() (connected bool) { conn.mu.Lock() defer conn.mu.Unlock() - if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting { + defer func() { + if !connected { + conn.logTraceConnState() + } + }() + + if conn.statusICE.Get() == StatusDisconnected { return false } @@ -805,20 +711,6 @@ func (conn *Conn) removeWgPeer() error { return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) } -func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) { - select { - case conn.relayDisconnected <- changed: - default: - } -} - -func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) { - select { - case conn.iCEDisconnected <- changed: - default: - } -} - func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { conn.log.Warnf("Failed to update wg peer configuration: %v", err) if wgProxy != nil { @@ -831,6 
+723,18 @@ func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { } } +func (conn *Conn) logTraceConnState() { + if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { + conn.log.Tracef("connectivity guard check, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) + } else { + conn.log.Tracef("connectivity guard check, ice state: %s", conn.statusICE) + } +} + +func isController(config ConnConfig) bool { + return config.LocalKey > config.Key +} + func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil } diff --git a/client/internal/peer/conn_monitor.go b/client/internal/peer/conn_monitor.go deleted file mode 100644 index 75722c99011..00000000000 --- a/client/internal/peer/conn_monitor.go +++ /dev/null @@ -1,212 +0,0 @@ -package peer - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/pion/ice/v3" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/stdnet" -) - -const ( - signalerMonitorPeriod = 5 * time.Second - candidatesMonitorPeriod = 5 * time.Minute - candidateGatheringTimeout = 5 * time.Second -) - -type ConnMonitor struct { - signaler *Signaler - iFaceDiscover stdnet.ExternalIFaceDiscover - config ConnConfig - relayDisconnected chan bool - iCEDisconnected chan bool - reconnectCh chan struct{} - currentCandidates []ice.Candidate - candidatesMu sync.Mutex -} - -func NewConnMonitor(signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, config ConnConfig, relayDisconnected, iCEDisconnected chan bool) (*ConnMonitor, <-chan struct{}) { - reconnectCh := make(chan struct{}, 1) - cm := &ConnMonitor{ - signaler: signaler, - iFaceDiscover: iFaceDiscover, - config: config, - relayDisconnected: relayDisconnected, - iCEDisconnected: iCEDisconnected, - reconnectCh: reconnectCh, - } - return cm, reconnectCh -} - -func (cm *ConnMonitor) Start(ctx context.Context) { - signalerReady := make(chan struct{}, 1) - go cm.monitorSignalerReady(ctx, signalerReady) - - localCandidatesChanged := make(chan struct{}, 1) - go cm.monitorLocalCandidatesChanged(ctx, localCandidatesChanged) - - for { - select { - case changed := <-cm.relayDisconnected: - if !changed { - continue - } - log.Debugf("Relay state changed, triggering reconnect") - cm.triggerReconnect() - - case changed := <-cm.iCEDisconnected: - if !changed { - continue - } - log.Debugf("ICE state changed, triggering reconnect") - cm.triggerReconnect() - - case <-signalerReady: - log.Debugf("Signaler became ready, triggering reconnect") - cm.triggerReconnect() - - case <-localCandidatesChanged: - log.Debugf("Local candidates changed, triggering reconnect") - cm.triggerReconnect() - - case <-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) monitorSignalerReady(ctx context.Context, signalerReady chan<- struct{}) { - if cm.signaler == nil { - return - } - - ticker := time.NewTicker(signalerMonitorPeriod) - defer ticker.Stop() - - lastReady := true - for { - select { - case <-ticker.C: - currentReady := cm.signaler.Ready() - if !lastReady && currentReady { - select { - case signalerReady <- struct{}{}: - default: - } - } - lastReady = currentReady - case <-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) monitorLocalCandidatesChanged(ctx context.Context, localCandidatesChanged chan<- struct{}) { - ufrag, pwd, err := generateICECredentials() - if err != nil { - log.Warnf("Failed to generate ICE credentials: %v", err) - return - } - - ticker := time.NewTicker(candidatesMonitorPeriod) - defer ticker.Stop() - 
- for { - select { - case <-ticker.C: - if err := cm.handleCandidateTick(ctx, localCandidatesChanged, ufrag, pwd); err != nil { - log.Warnf("Failed to handle candidate tick: %v", err) - } - case <-ctx.Done(): - return - } - } -} - -func (cm *ConnMonitor) handleCandidateTick(ctx context.Context, localCandidatesChanged chan<- struct{}, ufrag string, pwd string) error { - log.Debugf("Gathering ICE candidates") - - transportNet, err := newStdNet(cm.iFaceDiscover, cm.config.ICEConfig.InterfaceBlackList) - if err != nil { - log.Errorf("failed to create pion's stdnet: %s", err) - } - - agent, err := newAgent(cm.config, transportNet, candidateTypesP2P(), ufrag, pwd) - if err != nil { - return fmt.Errorf("create ICE agent: %w", err) - } - defer func() { - if err := agent.Close(); err != nil { - log.Warnf("Failed to close ICE agent: %v", err) - } - }() - - gatherDone := make(chan struct{}) - err = agent.OnCandidate(func(c ice.Candidate) { - log.Tracef("Got candidate: %v", c) - if c == nil { - close(gatherDone) - } - }) - if err != nil { - return fmt.Errorf("set ICE candidate handler: %w", err) - } - - if err := agent.GatherCandidates(); err != nil { - return fmt.Errorf("gather ICE candidates: %w", err) - } - - ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) - defer cancel() - - select { - case <-ctx.Done(): - return fmt.Errorf("wait for gathering: %w", ctx.Err()) - case <-gatherDone: - } - - candidates, err := agent.GetLocalCandidates() - if err != nil { - return fmt.Errorf("get local candidates: %w", err) - } - log.Tracef("Got candidates: %v", candidates) - - if changed := cm.updateCandidates(candidates); changed { - select { - case localCandidatesChanged <- struct{}{}: - default: - } - } - - return nil -} - -func (cm *ConnMonitor) updateCandidates(newCandidates []ice.Candidate) bool { - cm.candidatesMu.Lock() - defer cm.candidatesMu.Unlock() - - if len(cm.currentCandidates) != len(newCandidates) { - cm.currentCandidates = newCandidates - return true - } - - for i, candidate := range cm.currentCandidates { - if candidate.Address() != newCandidates[i].Address() { - cm.currentCandidates = newCandidates - return true - } - } - - return false -} - -func (cm *ConnMonitor) triggerReconnect() { - select { - case cm.reconnectCh <- struct{}{}: - default: - } -} diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index e68861c5f04..039952588d8 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -10,6 +10,8 @@ import ( "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/peer/guard" + "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/util" ) @@ -19,7 +21,7 @@ var connConf = ConnConfig{ LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", Timeout: time.Second, LocalWgPort: 51820, - ICEConfig: ICEConfig{ + ICEConfig: ice.Config{ InterfaceBlackList: nil, }, } @@ -43,7 +45,8 @@ func TestNewConn_interfaceFilter(t *testing.T) { } func TestConn_GetKey(t *testing.T) { - conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher) if err != nil { return } @@ -54,7 +57,8 @@ func TestConn_GetKey(t *testing.T) { } func TestConn_OnRemoteOffer(t *testing.T) { - conn, err := 
NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) if err != nil { return } @@ -87,7 +91,8 @@ func TestConn_OnRemoteOffer(t *testing.T) { } func TestConn_OnRemoteAnswer(t *testing.T) { - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) if err != nil { return } @@ -119,7 +124,8 @@ func TestConn_OnRemoteAnswer(t *testing.T) { wg.Wait() } func TestConn_Status(t *testing.T) { - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil) + swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) if err != nil { return } diff --git a/client/internal/peer/guard/guard.go b/client/internal/peer/guard/guard.go new file mode 100644 index 00000000000..bf3527a6264 --- /dev/null +++ b/client/internal/peer/guard/guard.go @@ -0,0 +1,194 @@ +package guard + +import ( + "context" + "time" + + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" +) + +const ( + reconnectMaxElapsedTime = 30 * time.Minute +) + +type isConnectedFunc func() bool + +// Guard is responsible for the reconnection logic. +// It will trigger to send an offer to the peer then has connection issues. +// Watch these events: +// - Relay client reconnected to home server +// - Signal server connection state changed +// - ICE connection disconnected +// - Relayed connection disconnected +// - ICE candidate changes +type Guard struct { + Reconnect chan struct{} + log *log.Entry + isController bool + isConnectedOnAllWay isConnectedFunc + timeout time.Duration + srWatcher *SRWatcher + relayedConnDisconnected chan bool + iCEConnDisconnected chan bool +} + +func NewGuard(log *log.Entry, isController bool, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { + return &Guard{ + Reconnect: make(chan struct{}, 1), + log: log, + isController: isController, + isConnectedOnAllWay: isConnectedFn, + timeout: timeout, + srWatcher: srWatcher, + relayedConnDisconnected: make(chan bool, 1), + iCEConnDisconnected: make(chan bool, 1), + } +} + +func (g *Guard) Start(ctx context.Context) { + if g.isController { + g.reconnectLoopWithRetry(ctx) + } else { + g.listenForDisconnectEvents(ctx) + } +} + +func (g *Guard) SetRelayedConnDisconnected(changed bool) { + select { + case g.relayedConnDisconnected <- changed: + default: + } +} + +func (g *Guard) SetICEConnDisconnected(changed bool) { + select { + case g.iCEConnDisconnected <- changed: + default: + } +} + +// reconnectLoopWithRetry periodically check (max 30 min) the connection status. 
+// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported +func (g *Guard) reconnectLoopWithRetry(ctx context.Context) { + waitForInitialConnectionTry(ctx) + + srReconnectedChan := g.srWatcher.NewListener() + defer g.srWatcher.RemoveListener(srReconnectedChan) + + ticker := g.prepareExponentTicker(ctx) + defer ticker.Stop() + + tickerChannel := ticker.C + + g.log.Infof("start reconnect loop...") + for { + select { + case t := <-tickerChannel: + if t.IsZero() { + g.log.Infof("retry timed out, stop periodic offer sending") + // after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop + tickerChannel = make(<-chan time.Time) + continue + } + + if !g.isConnectedOnAllWay() { + g.triggerOfferSending() + } + + case changed := <-g.relayedConnDisconnected: + if !changed { + continue + } + g.log.Debugf("Relay connection changed, reset reconnection ticker") + ticker.Stop() + ticker = g.prepareExponentTicker(ctx) + tickerChannel = ticker.C + + case changed := <-g.iCEConnDisconnected: + if !changed { + continue + } + g.log.Debugf("ICE connection changed, reset reconnection ticker") + ticker.Stop() + ticker = g.prepareExponentTicker(ctx) + tickerChannel = ticker.C + + case <-srReconnectedChan: + g.log.Debugf("has network changes, reset reconnection ticker") + ticker.Stop() + ticker = g.prepareExponentTicker(ctx) + tickerChannel = ticker.C + + case <-ctx.Done(): + g.log.Debugf("context is done, stop reconnect loop") + return + } + } +} + +// listenForDisconnectEvents is used when the peer is not a controller and it should reconnect to the peer +// when the connection is lost. It will try to establish a connection only once time if before the connection was established +// It track separately the ice and relay connection status. Just because a lower priority connection reestablished it does not +// mean that to switch to it. We always force to use the higher priority connection. +func (g *Guard) listenForDisconnectEvents(ctx context.Context) { + srReconnectedChan := g.srWatcher.NewListener() + defer g.srWatcher.RemoveListener(srReconnectedChan) + + g.log.Infof("start listen for reconnect events...") + for { + select { + case changed := <-g.relayedConnDisconnected: + if !changed { + continue + } + g.log.Debugf("Relay connection changed, triggering reconnect") + g.triggerOfferSending() + case changed := <-g.iCEConnDisconnected: + if !changed { + continue + } + g.log.Debugf("ICE state changed, try to send new offer") + g.triggerOfferSending() + case <-srReconnectedChan: + g.triggerOfferSending() + case <-ctx.Done(): + g.log.Debugf("context is done, stop reconnect loop") + return + } + } +} + +func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { + bo := backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 800 * time.Millisecond, + RandomizationFactor: 0.1, + Multiplier: 2, + MaxInterval: g.timeout, + MaxElapsedTime: reconnectMaxElapsedTime, + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) + + ticker := backoff.NewTicker(bo) + <-ticker.C // consume the initial tick what is happening right after the ticker has been created + + return ticker +} + +func (g *Guard) triggerOfferSending() { + select { + case g.Reconnect <- struct{}{}: + default: + } +} + +// Give chance to the peer to establish the initial connection. 
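A note on the retry pattern used by reconnectLoopWithRetry and prepareExponentTicker above: the cenkalti/backoff/v4 ticker fires once immediately, then at exponentially growing intervals, and closes its channel after MaxElapsedTime, which is why the loop consumes the first tick and treats a zero-value tick as "retries exhausted". A minimal, self-contained sketch of that behavior follows; the interval values are illustrative, not the ones the guard uses.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v4"
)

func main() {
	// Illustrative values; the guard uses an 800ms initial interval and a
	// 30-minute MaxElapsedTime.
	bo := backoff.WithContext(&backoff.ExponentialBackOff{
		InitialInterval:     200 * time.Millisecond,
		RandomizationFactor: 0.1,
		Multiplier:          2,
		MaxInterval:         2 * time.Second,
		MaxElapsedTime:      5 * time.Second,
		Stop:                backoff.Stop,
		Clock:               backoff.SystemClock,
	}, context.Background())

	ticker := backoff.NewTicker(bo)
	defer ticker.Stop()
	<-ticker.C // the first tick fires immediately; consume it

	for {
		t, ok := <-ticker.C
		if !ok || t.IsZero() {
			// channel closed after MaxElapsedTime: stop retrying
			fmt.Println("max elapsed time reached, giving up")
			return
		}
		fmt.Println("retry attempt at", t.Format(time.StampMilli))
	}
}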
+// With it, we can decrease to send necessary offer +func waitForInitialConnectionTry(ctx context.Context) { + select { + case <-ctx.Done(): + return + case <-time.After(3 * time.Second): + } +} diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go new file mode 100644 index 00000000000..b9c9aa1345c --- /dev/null +++ b/client/internal/peer/guard/ice_monitor.go @@ -0,0 +1,135 @@ +package guard + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/pion/ice/v3" + log "github.com/sirupsen/logrus" + + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +const ( + candidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second +) + +type ICEMonitor struct { + ReconnectCh chan struct{} + + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig icemaker.Config + + currentCandidates []ice.Candidate + candidatesMu sync.Mutex +} + +func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { + cm := &ICEMonitor{ + ReconnectCh: make(chan struct{}, 1), + iFaceDiscover: iFaceDiscover, + iceConfig: config, + } + return cm +} + +func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) { + ufrag, pwd, err := icemaker.GenerateICECredentials() + if err != nil { + log.Warnf("Failed to generate ICE credentials: %v", err) + return + } + + ticker := time.NewTicker(candidatesMonitorPeriod) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + changed, err := cm.handleCandidateTick(ctx, ufrag, pwd) + if err != nil { + log.Warnf("Failed to check ICE changes: %v", err) + continue + } + + if changed { + onChanged() + } + case <-ctx.Done(): + return + } + } +} + +func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) { + log.Debugf("Gathering ICE candidates") + + agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) + if err != nil { + return false, fmt.Errorf("create ICE agent: %w", err) + } + defer func() { + if err := agent.Close(); err != nil { + log.Warnf("Failed to close ICE agent: %v", err) + } + }() + + gatherDone := make(chan struct{}) + err = agent.OnCandidate(func(c ice.Candidate) { + log.Tracef("Got candidate: %v", c) + if c == nil { + close(gatherDone) + } + }) + if err != nil { + return false, fmt.Errorf("set ICE candidate handler: %w", err) + } + + if err := agent.GatherCandidates(); err != nil { + return false, fmt.Errorf("gather ICE candidates: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, candidateGatheringTimeout) + defer cancel() + + select { + case <-ctx.Done(): + return false, fmt.Errorf("wait for gathering timed out") + case <-gatherDone: + } + + candidates, err := agent.GetLocalCandidates() + if err != nil { + return false, fmt.Errorf("get local candidates: %w", err) + } + log.Tracef("Got candidates: %v", candidates) + + return cm.updateCandidates(candidates), nil +} + +func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool { + cm.candidatesMu.Lock() + defer cm.candidatesMu.Unlock() + + if len(cm.currentCandidates) != len(newCandidates) { + cm.currentCandidates = newCandidates + return true + } + + for i, candidate := range cm.currentCandidates { + if candidate.Address() != newCandidates[i].Address() { + cm.currentCandidates = newCandidates + return true + } + } + + return false +} + +func candidateTypesP2P() []ice.CandidateType { + return 
[]ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} +} diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go new file mode 100644 index 00000000000..90e45426f78 --- /dev/null +++ b/client/internal/peer/guard/sr_watcher.go @@ -0,0 +1,119 @@ +package guard + +import ( + "context" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +type chNotifier interface { + SetOnReconnectedListener(func()) + Ready() bool +} + +type SRWatcher struct { + signalClient chNotifier + relayManager chNotifier + + listeners map[chan struct{}]struct{} + mu sync.Mutex + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig ice.Config + + cancelIceMonitor context.CancelFunc +} + +// NewSRWatcher creates a new SRWatcher. This watcher will notify the listeners when the ICE candidates change or the +// Relay connection is reconnected or the Signal client reconnected. +func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscover stdnet.ExternalIFaceDiscover, iceConfig ice.Config) *SRWatcher { + srw := &SRWatcher{ + signalClient: signalClient, + relayManager: relayManager, + iFaceDiscover: iFaceDiscover, + iceConfig: iceConfig, + listeners: make(map[chan struct{}]struct{}), + } + return srw +} + +func (w *SRWatcher) Start() { + w.mu.Lock() + defer w.mu.Unlock() + + if w.cancelIceMonitor != nil { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + w.cancelIceMonitor = cancel + + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig) + go iceMonitor.Start(ctx, w.onICEChanged) + w.signalClient.SetOnReconnectedListener(w.onReconnected) + w.relayManager.SetOnReconnectedListener(w.onReconnected) + +} + +func (w *SRWatcher) Close() { + w.mu.Lock() + defer w.mu.Unlock() + + if w.cancelIceMonitor == nil { + return + } + w.cancelIceMonitor() + w.signalClient.SetOnReconnectedListener(nil) + w.relayManager.SetOnReconnectedListener(nil) +} + +func (w *SRWatcher) NewListener() chan struct{} { + w.mu.Lock() + defer w.mu.Unlock() + + listenerChan := make(chan struct{}, 1) + w.listeners[listenerChan] = struct{}{} + return listenerChan +} + +func (w *SRWatcher) RemoveListener(listenerChan chan struct{}) { + w.mu.Lock() + defer w.mu.Unlock() + delete(w.listeners, listenerChan) + close(listenerChan) +} + +func (w *SRWatcher) onICEChanged() { + if !w.signalClient.Ready() { + return + } + + log.Infof("network changes detected by ICE agent") + w.notify() +} + +func (w *SRWatcher) onReconnected() { + if !w.signalClient.Ready() { + return + } + if !w.relayManager.Ready() { + return + } + + log.Infof("reconnected to Signal or Relay server") + w.notify() +} + +func (w *SRWatcher) notify() { + w.mu.Lock() + defer w.mu.Unlock() + for listener := range w.listeners { + select { + case listener <- struct{}{}: + default: + } + } +} diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go new file mode 100644 index 00000000000..b2a9669367e --- /dev/null +++ b/client/internal/peer/ice/agent.go @@ -0,0 +1,89 @@ +package ice + +import ( + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/pion/ice/v3" + "github.com/pion/randutil" + "github.com/pion/stun/v2" + log "github.com/sirupsen/logrus" + "runtime" + "time" +) + +const ( + lenUFrag = 16 + lenPwd = 32 + runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + iceKeepAliveDefault = 4 * time.Second + 
iceDisconnectedTimeoutDefault = 6 * time.Second + // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package + iceRelayAcceptanceMinWaitDefault = 2 * time.Second +) + +var ( + failedTimeout = 6 * time.Second +) + +func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { + iceKeepAlive := iceKeepAlive() + iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() + + transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList) + if err != nil { + log.Errorf("failed to create pion's stdnet: %s", err) + } + + agentConfig := &ice.AgentConfig{ + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: config.StunTurn.Load().([]*stun.URI), + CandidateTypes: candidateTypes, + InterfaceFilter: stdnet.InterfaceFilter(config.InterfaceBlackList), + UDPMux: config.UDPMux, + UDPMuxSrflx: config.UDPMuxSrflx, + NAT1To1IPs: config.NATExternalIPs, + Net: transportNet, + FailedTimeout: &failedTimeout, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, + LocalUfrag: ufrag, + LocalPwd: pwd, + } + + if config.DisableIPv6Discovery { + agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} + } + + return ice.NewAgent(agentConfig) +} + +func GenerateICECredentials() (string, string, error) { + ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) + if err != nil { + return "", "", err + } + + pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha) + if err != nil { + return "", "", err + } + return ufrag, pwd, nil +} + +func CandidateTypes() []ice.CandidateType { + if hasICEForceRelayConn() { + return []ice.CandidateType{ice.CandidateTypeRelay} + } + // TODO: remove this once we have refactored userspace proxy into the bind package + if runtime.GOOS == "ios" { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} + } + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} +} + +func CandidateTypesP2P() []ice.CandidateType { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} +} diff --git a/client/internal/peer/ice/config.go b/client/internal/peer/ice/config.go new file mode 100644 index 00000000000..8abc842f0d2 --- /dev/null +++ b/client/internal/peer/ice/config.go @@ -0,0 +1,22 @@ +package ice + +import ( + "sync/atomic" + + "github.com/pion/ice/v3" +) + +type Config struct { + // StunTurn is a list of STUN and TURN URLs + StunTurn *atomic.Value // []*stun.URI + + // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering + // (e.g. 
if eth0 is in the list, host candidate of this interface won't be used) + InterfaceBlackList []string + DisableIPv6Discovery bool + + UDPMux ice.UDPMux + UDPMuxSrflx ice.UniversalUDPMux + + NATExternalIPs []string +} diff --git a/client/internal/peer/env_config.go b/client/internal/peer/ice/env.go similarity index 80% rename from client/internal/peer/env_config.go rename to client/internal/peer/ice/env.go index 87b626df763..3b0cb74ad2a 100644 --- a/client/internal/peer/env_config.go +++ b/client/internal/peer/ice/env.go @@ -1,4 +1,4 @@ -package peer +package ice import ( "os" @@ -10,12 +10,19 @@ import ( ) const ( + envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC" - envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" + + msgWarnInvalidValue = "invalid value %s set for %s, using default %v" ) +func hasICEForceRelayConn() bool { + disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn) + return strings.ToLower(disconnectedTimeoutEnv) == "true" +} + func iceKeepAlive() time.Duration { keepAliveEnv := os.Getenv(envICEKeepAliveIntervalSec) if keepAliveEnv == "" { @@ -25,7 +32,7 @@ func iceKeepAlive() time.Duration { log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv) keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv) if err != nil { - log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault) + log.Warnf(msgWarnInvalidValue, keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault) return iceKeepAliveDefault } @@ -41,7 +48,7 @@ func iceDisconnectedTimeout() time.Duration { log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv) disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv) if err != nil { - log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault) + log.Warnf(msgWarnInvalidValue, disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault) return iceDisconnectedTimeoutDefault } @@ -57,14 +64,9 @@ func iceRelayAcceptanceMinWait() time.Duration { log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv) disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv) if err != nil { - log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault) + log.Warnf(msgWarnInvalidValue, iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault) return iceRelayAcceptanceMinWaitDefault } return time.Duration(disconnectedTimeoutSec) * time.Second } - -func hasICEForceRelayConn() bool { - disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn) - return strings.ToLower(disconnectedTimeoutEnv) == "true" -} diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/ice/stdnet.go similarity index 94% rename from client/internal/peer/stdnet.go rename to client/internal/peer/ice/stdnet.go index 96d211dbc77..3ce83727e6e 100644 --- a/client/internal/peer/stdnet.go +++ b/client/internal/peer/ice/stdnet.go @@ -1,6 +1,6 @@ //go:build !android -package peer +package ice import ( "github.com/netbirdio/netbird/client/internal/stdnet" diff --git a/client/internal/peer/stdnet_android.go 
b/client/internal/peer/ice/stdnet_android.go similarity index 94% rename from client/internal/peer/stdnet_android.go rename to client/internal/peer/ice/stdnet_android.go index a39a03b1c83..84c665e6f40 100644 --- a/client/internal/peer/stdnet_android.go +++ b/client/internal/peer/ice/stdnet_android.go @@ -1,4 +1,4 @@ -package peer +package ice import "github.com/netbirdio/netbird/client/internal/stdnet" diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index c86c1858fdc..55894218d73 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -5,52 +5,20 @@ import ( "fmt" "net" "net/netip" - "runtime" "sync" - "sync/atomic" "time" "github.com/pion/ice/v3" - "github.com/pion/randutil" "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/bind" + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/route" ) -const ( - iceKeepAliveDefault = 4 * time.Second - iceDisconnectedTimeoutDefault = 6 * time.Second - // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package - iceRelayAcceptanceMinWaitDefault = 2 * time.Second - - lenUFrag = 16 - lenPwd = 32 - runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -) - -var ( - failedTimeout = 6 * time.Second -) - -type ICEConfig struct { - // StunTurn is a list of STUN and TURN URLs - StunTurn *atomic.Value // []*stun.URI - - // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering - // (e.g. if eth0 is in the list, host candidate of this interface won't be used) - InterfaceBlackList []string - DisableIPv6Discovery bool - - UDPMux ice.UDPMux - UDPMuxSrflx ice.UniversalUDPMux - - NATExternalIPs []string -} - type ICEConnInfo struct { RemoteConn net.Conn RosenpassPubKey []byte @@ -103,7 +71,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signal conn: callBacks, } - localUfrag, localPwd, err := generateICECredentials() + localUfrag, localPwd, err := icemaker.GenerateICECredentials() if err != nil { return nil, err } @@ -125,10 +93,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { var preferredCandidateTypes []ice.CandidateType if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { w.selectedPriority = connPriorityICEP2P - preferredCandidateTypes = candidateTypesP2P() + preferredCandidateTypes = icemaker.CandidateTypesP2P() } else { w.selectedPriority = connPriorityICETurn - preferredCandidateTypes = candidateTypes() + preferredCandidateTypes = icemaker.CandidateTypes() } w.log.Debugf("recreate ICE agent") @@ -232,15 +200,10 @@ func (w *WorkerICE) Close() { } } -func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { - transportNet, err := newStdNet(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) - if err != nil { - w.log.Errorf("failed to create pion's stdnet: %s", err) - } - +func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) { w.sentExtraSrflx = false - agent, err := newAgent(w.config, transportNet, relaySupport, w.localUfrag, w.localPwd) + agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) } @@ -365,36 
+328,6 @@ func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferA } } -func newAgent(config ConnConfig, transportNet *stdnet.Net, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ice.Agent, error) { - iceKeepAlive := iceKeepAlive() - iceDisconnectedTimeout := iceDisconnectedTimeout() - iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - - agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: config.ICEConfig.StunTurn.Load().([]*stun.URI), - CandidateTypes: candidateTypes, - InterfaceFilter: stdnet.InterfaceFilter(config.ICEConfig.InterfaceBlackList), - UDPMux: config.ICEConfig.UDPMux, - UDPMuxSrflx: config.ICEConfig.UDPMuxSrflx, - NAT1To1IPs: config.ICEConfig.NATExternalIPs, - Net: transportNet, - FailedTimeout: &failedTimeout, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, - RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, - LocalUfrag: ufrag, - LocalPwd: pwd, - } - - if config.ICEConfig.DisableIPv6Discovery { - agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} - } - - return ice.NewAgent(agentConfig) -} - func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ @@ -435,21 +368,6 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool return false } -func candidateTypes() []ice.CandidateType { - if hasICEForceRelayConn() { - return []ice.CandidateType{ice.CandidateTypeRelay} - } - // TODO: remove this once we have refactored userspace proxy into the bind package - if runtime.GOOS == "ios" { - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} - } - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} -} - -func candidateTypesP2P() []ice.CandidateType { - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} -} - func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } @@ -460,16 +378,3 @@ func isRelayed(pair *ice.CandidatePair) bool { } return false } - -func generateICECredentials() (string, string, error) { - ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) - if err != nil { - return "", "", err - } - - pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha) - if err != nil { - return "", "", err - } - return ufrag, pwd, nil -} diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index c02fccebc47..c22dcdeda5d 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -31,6 +31,7 @@ type WorkerRelayCallbacks struct { type WorkerRelay struct { log *log.Entry + isController bool config ConnConfig relayManager relayClient.ManagerService callBacks WorkerRelayCallbacks @@ -44,9 +45,10 @@ type WorkerRelay struct { relaySupportedOnRemotePeer atomic.Bool } -func NewWorkerRelay(log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { +func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { r := &WorkerRelay{ log: log, + isController: ctrl, config: config, relayManager: 
relayManager, callBacks: callbacks, @@ -80,6 +82,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.log.Errorf("failed to open connection via Relay: %s", err) return } + w.relayLock.Lock() w.relayedConn = relayedConn w.relayLock.Unlock() @@ -136,10 +139,6 @@ func (w *WorkerRelay) IsRelayConnectionSupportedWithPeer() bool { return w.relaySupportedOnRemotePeer.Load() && w.RelayIsSupportedLocally() } -func (w *WorkerRelay) IsController() bool { - return w.config.LocalKey > w.config.Key -} - func (w *WorkerRelay) RelayIsSupportedLocally() bool { return w.relayManager.HasRelayAddress() } @@ -212,7 +211,7 @@ func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { } func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string { - if w.IsController() { + if w.isController { return myRelayAddress } return remoteRelayAddress diff --git a/relay/client/client.go b/relay/client/client.go index 20a73f4b343..a82a75453bf 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -142,6 +142,7 @@ type Client struct { muInstanceURL sync.Mutex onDisconnectListener func() + onConnectedListener func() listenerMutex sync.Mutex } @@ -191,6 +192,7 @@ func (c *Client) Connect() error { c.wgReadLoop.Add(1) go c.readLoop(c.relayConn) + go c.notifyConnected() return nil } @@ -238,6 +240,12 @@ func (c *Client) SetOnDisconnectListener(fn func()) { c.onDisconnectListener = fn } +func (c *Client) SetOnConnectedListener(fn func()) { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + c.onConnectedListener = fn +} + // HasConns returns true if there are connections. func (c *Client) HasConns() bool { c.mu.Lock() @@ -245,6 +253,12 @@ func (c *Client) HasConns() bool { return len(c.conns) > 0 } +func (c *Client) Ready() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.serviceIsRunning +} + // Close closes the connection to the relay server and all connections to other peers. func (c *Client) Close() error { return c.close(true) @@ -363,9 +377,9 @@ func (c *Client) readLoop(relayConn net.Conn) { c.instanceURL = nil c.muInstanceURL.Unlock() - c.notifyDisconnected() c.wgReadLoop.Done() _ = c.close(false) + c.notifyDisconnected() } func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) { @@ -544,6 +558,16 @@ func (c *Client) notifyDisconnected() { go c.onDisconnectListener() } +func (c *Client) notifyConnected() { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + + if c.onConnectedListener == nil { + return + } + go c.onConnectedListener() +} + func (c *Client) writeCloseMsg() { msg := messages.MarshalCloseMsg() _, err := c.relayConn.Write(msg) diff --git a/relay/client/guard.go b/relay/client/guard.go index f826cf1b600..d6b6b0da509 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -29,6 +29,10 @@ func NewGuard(context context.Context, relayClient *Client) *Guard { // OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection // todo prevent multiple reconnection instances. 
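The relay client's SetOnConnectedListener/SetOnDisconnectListener handling above follows a small callback pattern: the listener is stored under a dedicated mutex and invoked on its own goroutine, so a slow or blocking listener cannot stall the read loop that observed the event. A rough sketch of that pattern with hypothetical names (Notifier, SetListener and notify are illustrative, not NetBird identifiers):

package main

import (
	"fmt"
	"sync"
	"time"
)

// Notifier is a hypothetical stand-in for the relay client's listener handling.
type Notifier struct {
	mu       sync.Mutex
	listener func()
}

// SetListener registers (or clears) the callback under the mutex.
func (n *Notifier) SetListener(fn func()) {
	n.mu.Lock()
	defer n.mu.Unlock()
	n.listener = fn
}

// notify invokes the callback on its own goroutine so the caller,
// for example a read loop that just observed a reconnect, is never blocked.
func (n *Notifier) notify() {
	n.mu.Lock()
	defer n.mu.Unlock()
	if n.listener == nil {
		return
	}
	go n.listener()
}

func main() {
	var n Notifier
	n.SetListener(func() { fmt.Println("connected") })
	n.notify()
	time.Sleep(10 * time.Millisecond) // give the goroutine time to run in this demo
}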
In the current usage it should not happen, but it is better to prevent func (g *Guard) OnDisconnected() { + if g.quickReconnect() { + return + } + ticker := time.NewTicker(reconnectingTimeout) defer ticker.Stop() @@ -46,3 +50,19 @@ func (g *Guard) OnDisconnected() { } } } + +func (g *Guard) quickReconnect() bool { + ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond) + defer cancel() + <-ctx.Done() + + if g.ctx.Err() != nil { + return false + } + + if err := g.relayClient.Connect(); err != nil { + log.Errorf("failed to reconnect to relay server: %s", err) + return false + } + return true +} diff --git a/relay/client/manager.go b/relay/client/manager.go index 4554c7c0f6e..3981415fcd4 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -65,6 +65,7 @@ type Manager struct { relayClientsMutex sync.RWMutex onDisconnectedListeners map[string]*list.List + onReconnectedListenerFn func() listenerLock sync.Mutex } @@ -101,6 +102,7 @@ func (m *Manager) Serve() error { m.relayClient = client m.reconnectGuard = NewGuard(m.ctx, m.relayClient) + m.relayClient.SetOnConnectedListener(m.onServerConnected) m.relayClient.SetOnDisconnectListener(func() { m.onServerDisconnected(client.connectionURL) }) @@ -138,6 +140,18 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { return netConn, err } +// Ready returns true if the home Relay client is connected to the relay server. +func (m *Manager) Ready() bool { + if m.relayClient == nil { + return false + } + return m.relayClient.Ready() +} + +func (m *Manager) SetOnReconnectedListener(f func()) { + m.onReconnectedListenerFn = f +} + // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // closed. func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { @@ -240,6 +254,13 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { return conn, nil } +func (m *Manager) onServerConnected() { + if m.onReconnectedListenerFn == nil { + return + } + go m.onReconnectedListenerFn() +} + func (m *Manager) onServerDisconnected(serverAddress string) { if serverAddress == m.relayClient.connectionURL { go m.reconnectGuard.OnDisconnected() diff --git a/signal/client/client.go b/signal/client/client.go index ced3fb7d0eb..eff1ccb8794 100644 --- a/signal/client/client.go +++ b/signal/client/client.go @@ -35,6 +35,7 @@ type Client interface { WaitStreamConnected() SendToStream(msg *proto.EncryptedMessage) error Send(msg *proto.Message) error + SetOnReconnectedListener(func()) } // UnMarshalCredential parses the credentials from the message and returns a Credential instance diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 7a3b502ffc6..2ff84e46075 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -43,6 +43,8 @@ type GrpcClient struct { connStateCallback ConnStateNotifier connStateCallbackLock sync.RWMutex + + onReconnectedListenerFn func() } func (c *GrpcClient) StreamConnected() bool { @@ -181,12 +183,17 @@ func (c *GrpcClient) notifyStreamDisconnected() { func (c *GrpcClient) notifyStreamConnected() { c.mux.Lock() defer c.mux.Unlock() + c.status = StreamConnected if c.connectedCh != nil { // there are goroutines waiting on this channel -> release them close(c.connectedCh) c.connectedCh = nil } + + if c.onReconnectedListenerFn != nil { + c.onReconnectedListenerFn() + } } func (c *GrpcClient) getStreamStatusChan() <-chan struct{} { @@ -271,6 +278,13 @@ func (c 
*GrpcClient) WaitStreamConnected() { } } +func (c *GrpcClient) SetOnReconnectedListener(fn func()) { + c.mux.Lock() + defer c.mux.Unlock() + + c.onReconnectedListenerFn = fn +} + // SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server // The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange // GrpcClient.connWg can be used to wait diff --git a/signal/client/mock.go b/signal/client/mock.go index 70ecea9eda2..32236c82c09 100644 --- a/signal/client/mock.go +++ b/signal/client/mock.go @@ -7,14 +7,20 @@ import ( ) type MockClient struct { - CloseFunc func() error - GetStatusFunc func() Status - StreamConnectedFunc func() bool - ReadyFunc func() bool - WaitStreamConnectedFunc func() - ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error - SendToStreamFunc func(msg *proto.EncryptedMessage) error - SendFunc func(msg *proto.Message) error + CloseFunc func() error + GetStatusFunc func() Status + StreamConnectedFunc func() bool + ReadyFunc func() bool + WaitStreamConnectedFunc func() + ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error + SendToStreamFunc func(msg *proto.EncryptedMessage) error + SendFunc func(msg *proto.Message) error + SetOnReconnectedListenerFunc func(f func()) +} + +// SetOnReconnectedListener sets the function to be called when the client reconnects. +func (sm *MockClient) SetOnReconnectedListener(_ func()) { + // Do nothing } func (sm *MockClient) IsHealthy() bool { From 8016710d241efb2b8dee03ff317128d3cec198b8 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:46:24 +0200 Subject: [PATCH 60/81] [client] Cleanup firewall state on startup (#2768) --- client/firewall/create.go | 4 +- client/firewall/create_linux.go | 82 +++++++----- client/firewall/iptables/acl_linux.go | 79 ++++++++--- client/firewall/iptables/manager_linux.go | 71 +++++++--- .../firewall/iptables/manager_linux_test.go | 20 +-- client/firewall/iptables/router_linux.go | 93 ++++++++++--- client/firewall/iptables/router_linux_test.go | 13 +- client/firewall/iptables/rulestore_linux.go | 57 +++++++- client/firewall/iptables/state_linux.go | 70 ++++++++++ client/firewall/manager/firewall.go | 6 +- client/firewall/nftables/acl_linux.go | 24 +--- client/firewall/nftables/manager_linux.go | 125 ++++++++++++++---- .../firewall/nftables/manager_linux_test.go | 13 +- client/firewall/nftables/router_linux.go | 28 ++-- client/firewall/nftables/router_linux_test.go | 13 +- client/firewall/nftables/state_linux.go | 47 +++++++ client/firewall/uspfilter/allow_netbird.go | 6 +- .../uspfilter/allow_netbird_windows.go | 4 +- client/firewall/uspfilter/uspfilter.go | 7 +- client/firewall/uspfilter/uspfilter_test.go | 6 +- client/internal/acl/manager_test.go | 9 +- client/internal/connect.go | 11 +- client/internal/dns/server.go | 7 + client/internal/engine.go | 5 +- .../routemanager/refcounter/refcounter.go | 40 ++++++ .../internal/routemanager/systemops/state.go | 67 ++-------- .../systemops/systemops_generic.go | 22 +-- client/internal/statemanager/path.go | 4 +- client/server/server.go | 33 ----- client/server/state.go | 37 ++++++ client/server/state_generic.go | 14 ++ client/server/state_linux.go | 18 +++ 32 files changed, 728 insertions(+), 307 deletions(-) create mode 100644 client/firewall/iptables/state_linux.go create mode 100644 
client/firewall/nftables/state_linux.go create mode 100644 client/server/state.go create mode 100644 client/server/state_generic.go create mode 100644 client/server/state_linux.go diff --git a/client/firewall/create.go b/client/firewall/create.go index 86ce94ceabb..9466f4b4d6b 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -3,7 +3,6 @@ package firewall import ( - "context" "fmt" "runtime" @@ -11,10 +10,11 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/statemanager" ) // NewFirewall creates a firewall manager instance -func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index 92deb63dc86..c853548f841 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -3,7 +3,6 @@ package firewall import ( - "context" "fmt" "os" @@ -15,6 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -32,54 +32,72 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - var fm firewall.Manager - var errFw error + fm, errFw := createNativeFirewall(iface) + if fm != nil { + if err := fm.Init(stateManager); err != nil { + log.Errorf("failed to init nftables manager: %s", err) + } + } + + if iface.IsUserspaceBind() { + return createUserspaceFirewall(iface, fm, errFw) + } + + return fm, errFw +} + +func createNativeFirewall(iface IFaceMapper) (firewall.Manager, error) { switch check() { case IPTABLES: - log.Info("creating an iptables firewall manager") - fm, errFw = nbiptables.Create(context, iface) - if errFw != nil { - log.Errorf("failed to create iptables manager: %s", errFw) - } + return createIptablesFirewall(iface) case NFTABLES: - log.Info("creating an nftables firewall manager") - fm, errFw = nbnftables.Create(context, iface) - if errFw != nil { - log.Errorf("failed to create nftables manager: %s", errFw) - } + return createNftablesFirewall(iface) default: - errFw = fmt.Errorf("no firewall manager found") log.Info("no firewall manager found, trying to use userspace packet filtering firewall") + return nil, fmt.Errorf("no firewall manager found") } +} - if iface.IsUserspaceBind() { - var errUsp error - if errFw == nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) - } else { - fm, errUsp = uspfilter.Create(iface) - } - if errUsp != nil { - log.Debugf("failed to create userspace filtering firewall: %s", errUsp) - return nil, errUsp - } +func createIptablesFirewall(iface IFaceMapper) (firewall.Manager, 
error) { + log.Info("creating an iptables firewall manager") + fm, err := nbiptables.Create(iface) + if err != nil { + log.Errorf("failed to create iptables manager: %s", err) + } + return fm, err +} - if err := fm.AllowNetbird(); err != nil { - log.Errorf("failed to allow netbird interface traffic: %v", err) - } - return fm, nil +func createNftablesFirewall(iface IFaceMapper) (firewall.Manager, error) { + log.Info("creating an nftables firewall manager") + fm, err := nbnftables.Create(iface) + if err != nil { + log.Errorf("failed to create nftables manager: %s", err) } + return fm, err +} - if errFw != nil { - return nil, errFw +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, errFw error) (firewall.Manager, error) { + var errUsp error + if errFw == nil { + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) + } else { + fm, errUsp = uspfilter.Create(iface) } + if errUsp != nil { + log.Debugf("failed to create userspace filtering firewall: %s", errUsp) + return nil, errUsp + } + + if err := fm.AllowNetbird(); err != nil { + log.Errorf("failed to allow netbird interface traffic: %v", err) + } return fm, nil } diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index c271e592dce..5cd69245b65 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -22,6 +23,8 @@ const ( chainNameOutputRules = "NETBIRD-ACL-OUTPUT" ) +type aclEntries map[string][][]string + type entry struct { spec []string position int @@ -32,9 +35,11 @@ type aclManager struct { wgIface iFaceMapper routingFwChainName string - entries map[string][][]string + entries aclEntries optionalEntries map[string][]entry ipsetStore *ipsetStore + + stateManager *statemanager.Manager } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { @@ -48,24 +53,30 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi ipsetStore: newIpsetStore(), } - err := ipset.Init() - if err != nil { - return nil, fmt.Errorf("failed to init ipset: %w", err) + if err := ipset.Init(); err != nil { + return nil, fmt.Errorf("init ipset: %w", err) } + return m, nil +} + +func (m *aclManager) init(stateManager *statemanager.Manager) error { + m.stateManager = stateManager + m.seedInitialEntries() m.seedInitialOptionalEntries() - err = m.cleanChains() - if err != nil { - return nil, err + if err := m.cleanChains(); err != nil { + return fmt.Errorf("clean chains: %w", err) } - err = m.createDefaultChains() - if err != nil { - return nil, err + if err := m.createDefaultChains(); err != nil { + return fmt.Errorf("create default chains: %w", err) } - return m, nil + + m.updateState() + + return nil } func (m *aclManager) AddPeerFiltering( @@ -146,6 +157,8 @@ func (m *aclManager) AddPeerFiltering( chain: chain, } + m.updateState() + return []firewall.Rule{rule}, nil } @@ -180,15 +193,23 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { } } - err := m.iptablesClient.Delete(tableName, r.chain, r.specs...) 
- if err != nil { - log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err) + if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil { + return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err) } - return err + + m.updateState() + + return nil } func (m *aclManager) Reset() error { - return m.cleanChains() + if err := m.cleanChains(); err != nil { + return fmt.Errorf("clean chains: %w", err) + } + + m.updateState() + + return nil } // todo write less destructive cleanup mechanism @@ -348,6 +369,32 @@ func (m *aclManager) appendToEntries(chainName string, spec []string) { m.entries[chainName] = append(m.entries[chainName], spec) } +func (m *aclManager) updateState() { + if m.stateManager == nil { + return + } + + var currentState *ShutdownState + if existing := m.stateManager.GetState(currentState); existing != nil { + if existingState, ok := existing.(*ShutdownState); ok { + currentState = existingState + } + } + if currentState == nil { + currentState = &ShutdownState{} + } + + currentState.Lock() + defer currentState.Unlock() + + currentState.ACLEntries = m.entries + currentState.ACLIPsetStore = m.ipsetStore + + if err := m.stateManager.UpdateState(currentState); err != nil { + log.Errorf("failed to update state: %v", err) + } +} + // filterRuleSpecs returns the specs of a filtering rule func filterRuleSpecs( ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string, diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 94bd2fccfe1..a59bd2c602e 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -8,10 +8,13 @@ import ( "sync" "github.com/coreos/go-iptables/iptables" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/statemanager" ) // Manager of iptables firewall @@ -33,10 +36,10 @@ type iFaceMapper interface { } // Create iptables firewall manager -func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper) (*Manager, error) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { - return nil, fmt.Errorf("iptables is not installed in the system or not supported") + return nil, fmt.Errorf("init iptables: %w", err) } m := &Manager{ @@ -44,20 +47,49 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouter(context, iptablesClient, wgIface) + m.router, err = newRouter(iptablesClient, wgIface) if err != nil { - log.Debugf("failed to initialize route related chains: %s", err) - return nil, err + return nil, fmt.Errorf("create router: %w", err) } + m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) if err != nil { - log.Debugf("failed to initialize ACL manager: %s", err) - return nil, err + return nil, fmt.Errorf("create acl manager: %w", err) } return m, nil } +func (m *Manager) Init(stateManager *statemanager.Manager) error { + state := &ShutdownState{ + InterfaceState: &InterfaceState{ + NameStr: m.wgIface.Name(), + WGAddress: m.wgIface.Address(), + UserspaceBind: m.wgIface.IsUserspaceBind(), + }, + } + stateManager.RegisterState(state) + if err := 
stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update state: %v", err) + } + + if err := m.router.init(stateManager); err != nil { + return fmt.Errorf("router init: %w", err) + } + + if err := m.aclMgr.init(stateManager); err != nil { + // TODO: cleanup router + return fmt.Errorf("acl manager init: %w", err) + } + + // persist early to ensure cleanup of chains + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + + return nil +} + // AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported @@ -133,20 +165,27 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { } // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - errAcl := m.aclMgr.Reset() - if errAcl != nil { - log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl) + var merr *multierror.Error + + if err := m.aclMgr.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) + } + if err := m.router.Reset(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) } - errMgr := m.router.Reset() - if errMgr != nil { - log.Errorf("failed to clean up router rules from firewall: %s", errMgr) - return errMgr + + // attempt to delete state only if all other operations succeeded + if merr == nil { + if err := stateManager.DeleteState(&ShutdownState{}); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err)) + } } - return errAcl + + return nberrors.FormatErrorOrNil(merr) } // AllowNetbird allows netbird interface traffic diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 498d8f58b09..ebdb831376f 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -1,7 +1,6 @@ package iptables import ( - "context" "fmt" "net" "testing" @@ -56,13 +55,14 @@ func TestIptablesManager(t *testing.T) { require.NoError(t, err) // just check on the local interface - manager, err := Create(context.Background(), ifaceMock) + manager, err := Create(ifaceMock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) defer func() { - err := manager.Reset() + err := manager.Reset(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) @@ -122,7 +122,7 @@ func TestIptablesManager(t *testing.T) { _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") require.NoError(t, err, "failed to add rule") - err = manager.Reset() + err = manager.Reset(nil) require.NoError(t, err, "failed to reset") ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) @@ -154,13 +154,14 @@ func TestIptablesManagerIPSet(t *testing.T) { } // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(mock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) defer func() { - err := manager.Reset() + err := manager.Reset(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) @@ -219,7 +220,7 @@ func TestIptablesManagerIPSet(t *testing.T) { }) t.Run("reset check", func(t *testing.T) { - err = manager.Reset() + err = manager.Reset(nil) 
require.NoError(t, err, "failed to reset") }) } @@ -251,12 +252,13 @@ func TestIptablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(mock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) defer func() { - err := manager.Reset() + err := manager.Reset(nil) require.NoError(t, err, "clear the manager state") time.Sleep(time.Second) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 12932392871..90811ae1182 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -3,7 +3,6 @@ package iptables import ( - "context" "fmt" "net/netip" "strconv" @@ -18,6 +17,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -48,28 +48,31 @@ type routeFilteringRuleParams struct { SetName string } +type routeRules map[string][]string + +type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] + type router struct { - ctx context.Context - stop context.CancelFunc iptablesClient *iptables.IPTables - rules map[string][]string - ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}] + rules routeRules + ipsetCounter *ipsetCounter wgIface iFaceMapper legacyManagement bool + + stateManager *statemanager.Manager } -func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { - ctx, cancel := context.WithCancel(parentCtx) +func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { r := &router{ - ctx: ctx, - stop: cancel, iptablesClient: iptablesClient, rules: make(map[string][]string), wgIface: wgIface, } r.ipsetCounter = refcounter.New( - r.createIpSet, + func(name string, sources []netip.Prefix) (struct{}, error) { + return struct{}{}, r.createIpSet(name, sources) + }, func(name string, _ struct{}) error { return r.deleteIpSet(name) }, @@ -79,16 +82,23 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI return nil, fmt.Errorf("init ipset: %w", err) } - err := r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("cleanup routing rules: %s", err) - return nil, err + return r, nil +} + +func (r *router) init(stateManager *statemanager.Manager) error { + r.stateManager = stateManager + + if err := r.cleanUpDefaultForwardRules(); err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } - err = r.createContainers() - if err != nil { - log.Errorf("create containers for route: %s", err) + + if err := r.createContainers(); err != nil { + return fmt.Errorf("create containers: %w", err) } - return r, err + + r.updateState() + + return nil } func (r *router) AddRouteFiltering( @@ -129,6 +139,8 @@ func (r *router) AddRouteFiltering( r.rules[string(ruleKey)] = rule + r.updateState() + return ruleKey, nil } @@ -152,6 +164,8 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error { log.Debugf("route rule %s not found", ruleKey) } + r.updateState() + return nil } @@ -164,18 +178,18 @@ func (r *router) 
findSetNameInRule(rule []string) string { return "" } -func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) { +func (r *router) createIpSet(setName string, sources []netip.Prefix) error { if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { - return struct{}{}, fmt.Errorf("create set %s: %w", setName, err) + return fmt.Errorf("create set %s: %w", setName, err) } for _, prefix := range sources { if err := ipset.AddPrefix(setName, prefix); err != nil { - return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err) + return fmt.Errorf("add element to set %s: %w", setName, err) } } - return struct{}{}, nil + return nil } func (r *router) deleteIpSet(setName string) error { @@ -206,6 +220,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error { return fmt.Errorf("add inverse nat rule: %w", err) } + r.updateState() + return nil } @@ -223,6 +239,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error { return fmt.Errorf("remove legacy routing rule: %w", err) } + r.updateState() + return nil } @@ -280,6 +298,9 @@ func (r *router) RemoveAllLegacyRouteRules() error { merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) } } + + r.updateState() + return nberrors.FormatErrorOrNil(merr) } @@ -294,6 +315,8 @@ func (r *router) Reset() error { merr = multierror.Append(merr, err) } + r.updateState() + return nberrors.FormatErrorOrNil(merr) } @@ -431,6 +454,32 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error { return nil } +func (r *router) updateState() { + if r.stateManager == nil { + return + } + + var currentState *ShutdownState + if existing := r.stateManager.GetState(currentState); existing != nil { + if existingState, ok := existing.(*ShutdownState); ok { + currentState = existingState + } + } + if currentState == nil { + currentState = &ShutdownState{} + } + + currentState.Lock() + defer currentState.Unlock() + + currentState.RouteRules = r.rules + currentState.RouteIPsetCounter = r.ipsetCounter + + if err := r.stateManager.UpdateState(currentState); err != nil { + log.Errorf("failed to update state: %v", err) + } +} + func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { intdir := "-i" lointdir := "-o" diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 6cede09e2b9..2d821a9db7f 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -3,7 +3,6 @@ package iptables import ( - "context" "net/netip" "os/exec" "testing" @@ -30,8 +29,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "should return a valid iptables manager") + require.NoError(t, manager.init(nil)) defer func() { _ = manager.Reset() @@ -74,8 +74,9 @@ func TestIptablesManager_AddNatRule(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") + require.NoError(t, manager.init(nil)) defer func() { 
err := manager.Reset() @@ -132,8 +133,9 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") + require.NoError(t, manager.init(nil)) defer func() { _ = manager.Reset() }() @@ -183,8 +185,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "Failed to create iptables client") - r, err := newRouter(context.Background(), iptablesClient, ifaceMock) + r, err := newRouter(iptablesClient, ifaceMock) require.NoError(t, err, "Failed to create router manager") + require.NoError(t, r.init(nil)) defer func() { err := r.Reset() diff --git a/client/firewall/iptables/rulestore_linux.go b/client/firewall/iptables/rulestore_linux.go index a9470c9ac72..bfd08bee27d 100644 --- a/client/firewall/iptables/rulestore_linux.go +++ b/client/firewall/iptables/rulestore_linux.go @@ -1,14 +1,16 @@ package iptables +import "encoding/json" + type ipList struct { ips map[string]struct{} } -func newIpList(ip string) ipList { +func newIpList(ip string) *ipList { ips := make(map[string]struct{}) ips[ip] = struct{}{} - return ipList{ + return &ipList{ ips: ips, } } @@ -17,27 +19,47 @@ func (s *ipList) addIP(ip string) { s.ips[ip] = struct{}{} } +// MarshalJSON implements json.Marshaler +func (s *ipList) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + IPs map[string]struct{} `json:"ips"` + }{ + IPs: s.ips, + }) +} + +// UnmarshalJSON implements json.Unmarshaler +func (s *ipList) UnmarshalJSON(data []byte) error { + temp := struct { + IPs map[string]struct{} `json:"ips"` + }{} + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + s.ips = temp.IPs + return nil +} + type ipsetStore struct { - ipsets map[string]ipList // ipsetName -> ruleset + ipsets map[string]*ipList } func newIpsetStore() *ipsetStore { return &ipsetStore{ - ipsets: make(map[string]ipList), + ipsets: make(map[string]*ipList), } } -func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) { +func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) { r, ok := s.ipsets[ipsetName] return r, ok } -func (s *ipsetStore) addIpList(ipsetName string, list ipList) { +func (s *ipsetStore) addIpList(ipsetName string, list *ipList) { s.ipsets[ipsetName] = list } func (s *ipsetStore) deleteIpset(ipsetName string) { - s.ipsets[ipsetName] = ipList{} delete(s.ipsets, ipsetName) } @@ -48,3 +70,24 @@ func (s *ipsetStore) ipsetNames() []string { } return names } + +// MarshalJSON implements json.Marshaler +func (s *ipsetStore) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + IPSets map[string]*ipList `json:"ipsets"` + }{ + IPSets: s.ipsets, + }) +} + +// UnmarshalJSON implements json.Unmarshaler +func (s *ipsetStore) UnmarshalJSON(data []byte) error { + temp := struct { + IPSets map[string]*ipList `json:"ipsets"` + }{} + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + s.ipsets = temp.IPSets + return nil +} diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go new file mode 100644 index 00000000000..44b8340ba75 --- /dev/null +++ b/client/firewall/iptables/state_linux.go @@ -0,0 +1,70 @@ +package iptables + +import ( + "fmt" + "sync" + + "github.com/netbirdio/netbird/client/iface" + 
"github.com/netbirdio/netbird/client/iface/device" +) + +type InterfaceState struct { + NameStr string `json:"name"` + WGAddress iface.WGAddress `json:"wg_address"` + UserspaceBind bool `json:"userspace_bind"` +} + +func (i *InterfaceState) Name() string { + return i.NameStr +} + +func (i *InterfaceState) Address() device.WGAddress { + return i.WGAddress +} + +func (i *InterfaceState) IsUserspaceBind() bool { + return i.UserspaceBind +} + +type ShutdownState struct { + sync.Mutex + + InterfaceState *InterfaceState `json:"interface_state,omitempty"` + + RouteRules routeRules `json:"route_rules,omitempty"` + RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"` + + ACLEntries aclEntries `json:"acl_entries,omitempty"` + ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"` +} + +func (s *ShutdownState) Name() string { + return "iptables_state" +} + +func (s *ShutdownState) Cleanup() error { + ipt, err := Create(s.InterfaceState) + if err != nil { + return fmt.Errorf("create iptables manager: %w", err) + } + + if s.RouteRules != nil { + ipt.router.rules = s.RouteRules + } + if s.RouteIPsetCounter != nil { + ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter) + } + + if s.ACLEntries != nil { + ipt.aclMgr.entries = s.ACLEntries + } + if s.ACLIPsetStore != nil { + ipt.aclMgr.ipsetStore = s.ACLIPsetStore + } + + if err := ipt.Reset(nil); err != nil { + return fmt.Errorf("reset iptables manager: %w", err) + } + + return nil +} diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 556bda0d6b1..2a40cd9f68c 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -10,6 +10,8 @@ import ( "strings" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -52,6 +54,8 @@ const ( // It declares methods which handle actions required by the // Netbird client for ACL and routing functionality type Manager interface { + Init(stateManager *statemanager.Manager) error + // AllowNetbird allows netbird interface traffic AllowNetbird() error @@ -91,7 +95,7 @@ type Manager interface { SetLegacyManagement(legacy bool) error // Reset firewall to the default state - Reset() error + Reset(stateManager *statemanager.Manager) error // Flush the changes to firewall controller Flush() error diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 61434f03518..ca7b2e59fbc 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -17,7 +17,6 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -56,13 +55,6 @@ type AclManager struct { rules map[string]*Rule } -// iFaceMapper defines subset methods of interface required for manager -type iFaceMapper interface { - Name() string - Address() iface.WGAddress - IsUserspaceBind() bool -} - func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) @@ -70,10 +62,10 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam // overloads netlink with high amount of rules ( > 10000) sConn, err := nftables.New(nftables.AsLasting()) if err != nil { - return nil, err + return nil, 
fmt.Errorf("create nf conn: %w", err) } - m := &AclManager{ + return &AclManager{ rConn: &nftables.Conn{}, sConn: sConn, wgIface: wgIface, @@ -82,14 +74,12 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam ipsetStore: newIpsetStore(), rules: make(map[string]*Rule), - } - - err = m.createDefaultChains() - if err != nil { - return nil, err - } + }, nil +} - return m, nil +func (m *AclManager) init(workTable *nftables.Table) error { + m.workTable = workTable + return m.createDefaultChains() } // AddPeerFiltering rule to the firewall diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 01b08bd7111..a4650f3b626 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -14,6 +14,8 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -24,6 +26,13 @@ const ( chainNameInput = "INPUT" ) +// iFaceMapper defines subset methods of interface required for manager +type iFaceMapper interface { + Name() string + Address() iface.WGAddress + IsUserspaceBind() bool +} + // Manager of iptables firewall type Manager struct { mutex sync.Mutex @@ -35,30 +44,68 @@ type Manager struct { } // Create nftables firewall manager -func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper) (*Manager, error) { m := &Manager{ rConn: &nftables.Conn{}, wgIface: wgIface, } - workTable, err := m.createWorkTable() - if err != nil { - return nil, err - } + workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} - m.router, err = newRouter(context, workTable, wgIface) + var err error + m.router, err = newRouter(workTable, wgIface) if err != nil { - return nil, err + return nil, fmt.Errorf("create router: %w", err) } m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) if err != nil { - return nil, err + return nil, fmt.Errorf("create acl manager: %w", err) } return m, nil } +// Init nftables firewall manager +func (m *Manager) Init(stateManager *statemanager.Manager) error { + workTable, err := m.createWorkTable() + if err != nil { + return fmt.Errorf("create work table: %w", err) + } + + if err := m.router.init(workTable); err != nil { + return fmt.Errorf("router init: %w", err) + } + + if err := m.aclManager.init(workTable); err != nil { + // TODO: cleanup router + return fmt.Errorf("acl manager init: %w", err) + } + + stateManager.RegisterState(&ShutdownState{}) + + // We only need to record minimal interface state for potential recreation. + // Unlike iptables, which requires tracking individual rules, nftables maintains + // a known state (our netbird table plus a few static rules). This allows for easy + // cleanup using Reset() without needing to store specific rules. 
+ if err := stateManager.UpdateState(&ShutdownState{ + InterfaceState: &InterfaceState{ + NameStr: m.wgIface.Name(), + WGAddress: m.wgIface.Address(), + UserspaceBind: m.wgIface.IsUserspaceBind(), + }, + }); err != nil { + log.Errorf("failed to update state: %v", err) + } + + // persist early + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + + return nil +} + // AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set @@ -203,48 +250,80 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { } // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() + if err := m.resetNetbirdInputRules(); err != nil { + return fmt.Errorf("reset netbird input rules: %v", err) + } + + if err := m.router.Reset(); err != nil { + return fmt.Errorf("reset router: %v", err) + } + + if err := m.cleanupNetbirdTables(); err != nil { + return fmt.Errorf("cleanup netbird tables: %v", err) + } + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + if err := stateManager.DeleteState(&ShutdownState{}); err != nil { + return fmt.Errorf("delete state: %v", err) + } + + return nil +} + +func (m *Manager) resetNetbirdInputRules() error { chains, err := m.rConn.ListChains() if err != nil { - return fmt.Errorf("list of chains: %w", err) + return fmt.Errorf("list chains: %w", err) } + m.deleteNetbirdInputRules(chains) + + return nil +} + +func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { for _, c := range chains { - // delete Netbird allow input traffic rule if it exists if c.Table.Name == "filter" && c.Name == "INPUT" { rules, err := m.rConn.GetRules(c.Table, c) if err != nil { log.Errorf("get rules for chain %q: %v", c.Name, err) continue } - for _, r := range rules { - if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { - if err := m.rConn.DelRule(r); err != nil { - log.Errorf("delete rule: %v", err) - } - } - } + + m.deleteMatchingRules(rules) } } +} - if err := m.router.Reset(); err != nil { - return fmt.Errorf("reset forward rules: %v", err) +func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) { + for _, r := range rules { + if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { + if err := m.rConn.DelRule(r); err != nil { + log.Errorf("delete rule: %v", err) + } + } } +} +func (m *Manager) cleanupNetbirdTables() error { tables, err := m.rConn.ListTables() if err != nil { - return fmt.Errorf("list of tables: %w", err) + return fmt.Errorf("list tables: %w", err) } + for _, t := range tables { if t.Name == tableNameNetbird { m.rConn.DelTable(t) } } - - return m.rConn.Flush() + return nil } // Flush rule/chain/set operations from the buffer diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index bbe18ab0714..77f4f03066e 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -1,7 +1,6 @@ package nftables import ( - "context" "fmt" "net" "net/netip" @@ -58,12 +57,13 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { // just check on the local interface - manager, err := Create(context.Background(), ifaceMock) + manager, err := Create(ifaceMock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) 
time.Sleep(time.Second * 3) defer func() { - err = manager.Reset() + err = manager.Reset(nil) require.NoError(t, err, "failed to reset") time.Sleep(time.Second) }() @@ -169,7 +169,7 @@ func TestNftablesManager(t *testing.T) { // established rule remains require.Len(t, rules, 1, "expected 1 rules after deletion") - err = manager.Reset() + err = manager.Reset(nil) require.NoError(t, err, "failed to reset") } @@ -192,12 +192,13 @@ func TestNFtablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(mock) require.NoError(t, err) + require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) defer func() { - if err := manager.Reset(); err != nil { + if err := manager.Reset(nil); err != nil { t.Errorf("clear the manager state: %v", err) } time.Sleep(time.Second) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 03526fee7b9..9b28e4eb213 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -2,7 +2,6 @@ package nftables import ( "bytes" - "context" "encoding/binary" "errors" "fmt" @@ -40,8 +39,6 @@ var ( ) type router struct { - ctx context.Context - stop context.CancelFunc conn *nftables.Conn workTable *nftables.Table filterTable *nftables.Table @@ -54,12 +51,8 @@ type router struct { legacyManagement bool } -func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { - ctx, cancel := context.WithCancel(parentCtx) - +func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { r := &router{ - ctx: ctx, - stop: cancel, conn: &nftables.Conn{}, workTable: workTable, chains: make(map[string]*nftables.Chain), @@ -78,20 +71,25 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa if errors.Is(err, errFilterTableNotFound) { log.Warnf("table 'filter' not found for forward rules") } else { - return nil, err + return nil, fmt.Errorf("load filter table: %w", err) } } - err = r.removeAcceptForwardRules() - if err != nil { + return r, nil +} + +func (r *router) init(workTable *nftables.Table) error { + r.workTable = workTable + + if err := r.removeAcceptForwardRules(); err != nil { log.Errorf("failed to clean up rules from FORWARD chain: %s", err) } - err = r.createContainers() - if err != nil { - log.Errorf("failed to create containers for route: %s", err) + if err := r.createContainers(); err != nil { + return fmt.Errorf("create containers: %w", err) } - return r, err + + return nil } // Reset cleans existing nftables default forward rules from the system diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index c07111b4e10..19ed48991f1 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -3,7 +3,6 @@ package nftables import ( - "context" "encoding/binary" "net/netip" "os/exec" @@ -40,8 +39,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) { for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table, ifaceMock) + manager, err := newRouter(table, ifaceMock) require.NoError(t, err, "failed to create router") + 
require.NoError(t, manager.init(table)) nftablesTestingClient := &nftables.Conn{} @@ -142,8 +142,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table, ifaceMock) + manager, err := newRouter(table, ifaceMock) require.NoError(t, err, "failed to create router") + require.NoError(t, manager.init(table)) nftablesTestingClient := &nftables.Conn{} @@ -210,8 +211,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(context.Background(), workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock) require.NoError(t, err, "Failed to create router") + require.NoError(t, r.init(workTable)) defer func(r *router) { require.NoError(t, r.Reset(), "Failed to reset rules") @@ -376,8 +378,9 @@ func TestNftablesCreateIpSet(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(context.Background(), workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock) require.NoError(t, err, "Failed to create router") + require.NoError(t, r.init(workTable)) defer func() { require.NoError(t, r.Reset(), "Failed to reset router") diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go new file mode 100644 index 00000000000..a68c8b8b882 --- /dev/null +++ b/client/firewall/nftables/state_linux.go @@ -0,0 +1,47 @@ +package nftables + +import ( + "fmt" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +type InterfaceState struct { + NameStr string `json:"name"` + WGAddress iface.WGAddress `json:"wg_address"` + UserspaceBind bool `json:"userspace_bind"` +} + +func (i *InterfaceState) Name() string { + return i.NameStr +} + +func (i *InterfaceState) Address() device.WGAddress { + return i.WGAddress +} + +func (i *InterfaceState) IsUserspaceBind() bool { + return i.UserspaceBind +} + +type ShutdownState struct { + InterfaceState *InterfaceState `json:"interface_state,omitempty"` +} + +func (s *ShutdownState) Name() string { + return "nftables_state" +} + +func (s *ShutdownState) Cleanup() error { + nft, err := Create(s.InterfaceState) + if err != nil { + return fmt.Errorf("create nftables manager: %w", err) + } + + if err := nft.Reset(nil); err != nil { + return fmt.Errorf("reset nftables manager: %w", err) + } + + return nil +} diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 2275dad3998..cefc81a3ce6 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -2,8 +2,10 @@ package uspfilter +import "github.com/netbirdio/netbird/client/internal/statemanager" + // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -11,7 +13,7 @@ func (m *Manager) Reset() error { m.incomingRules = make(map[string]RuleSet) if m.nativeFirewall != nil { - return m.nativeFirewall.Reset() + return m.nativeFirewall.Reset(stateManager) } return nil } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 34274564fa3..d3732301ed5 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -6,6 +6,8 @@ import ( "syscall" log "github.com/sirupsen/logrus" + + 
"github.com/netbirdio/netbird/client/internal/statemanager" ) type action string @@ -17,7 +19,7 @@ const ( ) // Reset firewall to the default state -func (m *Manager) Reset() error { +func (m *Manager) Reset(*statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 0e3ee97991f..3829a9baffe 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -14,6 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/statemanager" ) const layerTypeAll = 0 @@ -97,6 +98,10 @@ func create(iface IFaceMapper) (*Manager, error) { return m, nil } +func (m *Manager) Init(*statemanager.Manager) error { + return nil +} + func (m *Manager) IsServerRouteSupported() bool { if m.nativeFirewall == nil { return false @@ -190,7 +195,7 @@ func (m *Manager) AddPeerFiltering( return []firewall.Rule{&r}, nil } -func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) { +func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { if m.nativeFirewall == nil { return nil, errRouteNotSupported } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index c188deea460..d7c93cb7f99 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -259,7 +259,7 @@ func TestManagerReset(t *testing.T) { return } - err = m.Reset() + err = m.Reset(nil) if err != nil { t.Errorf("failed to reset Manager: %v", err) return @@ -330,7 +330,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if err = m.Reset(); err != nil { + if err = m.Reset(nil); err != nil { t.Errorf("failed to reset Manager: %v", err) return } @@ -396,7 +396,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { time.Sleep(time.Second) defer func() { - if err := manager.Reset(); err != nil { + if err := manager.Reset(nil); err != nil { t.Errorf("clear the manager state: %v", err) } time.Sleep(time.Second) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 7d999669abb..9a766021a45 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -1,7 +1,6 @@ package acl import ( - "context" "net" "testing" @@ -52,13 +51,13 @@ func TestDefaultManager(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(context.Background(), ifaceMock) + fw, err := firewall.NewFirewall(ifaceMock, nil) if err != nil { t.Errorf("create firewall: %v", err) return } defer func(fw manager.Manager) { - _ = fw.Reset() + _ = fw.Reset(nil) }(fw) acl := NewDefaultManager(fw) @@ -345,13 +344,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(context.Background(), ifaceMock) + fw, err := firewall.NewFirewall(ifaceMock, nil) if err != nil { t.Errorf("create firewall: %v", err) return } defer func(fw manager.Manager) { 
- _ = fw.Reset() + _ = fw.Reset(nil) }(fw) acl := NewDefaultManager(fw) diff --git a/client/internal/connect.go b/client/internal/connect.go index 13f10fbf1e6..bcc9d17a3f6 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -62,10 +62,7 @@ func (c *ConnectClient) Run() error { } // RunWithProbes runs the client's main logic with probes attached -func (c *ConnectClient) RunWithProbes( - probes *ProbeHolder, - runningChan chan error, -) error { +func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error { return c.run(MobileDependency{}, probes, runningChan) } @@ -104,11 +101,7 @@ func (c *ConnectClient) RunOniOS( return c.run(mobileDependency, nil, nil) } -func (c *ConnectClient) run( - mobileDependency MobileDependency, - probes *ProbeHolder, - runningChan chan error, -) error { +func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error { defer func() { if r := recover(); r != nil { log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 772797fac0a..929e1e60c85 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -533,6 +533,13 @@ func (s *DefaultServer) upstreamCallbacks( l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } + // persist dns state right away + ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) + defer cancel() + if err := s.stateManager.PersistState(ctx); err != nil { + l.Errorf("Failed to persist dns state: %v", err) + } + if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { s.addHostRootZone() } diff --git a/client/internal/engine.go b/client/internal/engine.go index af2817e6ed3..190d795cdbe 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -38,6 +38,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -366,7 +367,7 @@ func (e *Engine) Start() error { return fmt.Errorf("create wg interface: %w", err) } - e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) if err != nil { log.Errorf("failed creating firewall manager: %s", err) } @@ -1167,7 +1168,7 @@ func (e *Engine) close() { } if e.firewall != nil { - err := e.firewall.Reset() + err := e.firewall.Reset(e.stateManager) if err != nil { log.Warnf("failed to reset firewall: %s", err) } diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index 65ea0f708ea..c121b7d774b 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -1,6 +1,7 @@ package refcounter import ( + "encoding/json" "errors" "fmt" "runtime" @@ -70,6 +71,19 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key } } +// LoadData loads the data from the existing counter +func (rm *Counter[Key, I, O]) LoadData( + existingCounter *Counter[Key, I, O], +) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + + rm.refCountMap = existingCounter.refCountMap + rm.idMap = existingCounter.idMap +} + // Get retrieves the current 
reference count and associated data for a key. // If the key doesn't exist, it returns a zero value Ref and false. func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { @@ -201,6 +215,32 @@ func (rm *Counter[Key, I, O]) Clear() { clear(rm.idMap) } +// MarshalJSON implements the json.Marshaler interface for Counter. +func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + RefCountMap map[Key]Ref[O] `json:"refCountMap"` + IDMap map[string][]Key `json:"idMap"` + }{ + RefCountMap: rm.refCountMap, + IDMap: rm.idMap, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for Counter. +func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { + var temp struct { + RefCountMap map[Key]Ref[O] `json:"refCountMap"` + IDMap map[string][]Key `json:"idMap"` + } + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + rm.refCountMap = temp.RefCountMap + rm.idMap = temp.IDMap + + return nil +} + func getCallerInfo(depth int, maxDepth int) (string, bool) { if depth >= maxDepth { return "", false diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 26992467750..42590892297 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -1,30 +1,15 @@ package systemops import ( - "encoding/json" - "fmt" "net/netip" "sync" - "github.com/hashicorp/go-multierror" - - nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) -type RouteEntry struct { - Prefix netip.Prefix `json:"prefix"` - Nexthop Nexthop `json:"nexthop"` -} - type ShutdownState struct { - Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"` - mu sync.RWMutex -} - -func NewShutdownState() *ShutdownState { - return &ShutdownState{ - Routes: make(map[netip.Prefix]RouteEntry), - } + Counter *ExclusionCounter `json:"counter,omitempty"` + mu sync.RWMutex } func (s *ShutdownState) Name() string { @@ -32,50 +17,16 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - sysops := NewSysOps(nil, nil) - var merr *multierror.Error - - s.mu.RLock() - defer s.mu.RUnlock() - - for _, route := range s.Routes { - if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err)) - } - } - - return nberrors.FormatErrorOrNil(merr) -} - -func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) { - s.mu.Lock() - defer s.mu.Unlock() - - s.Routes[prefix] = RouteEntry{ - Prefix: prefix, - Nexthop: nexthop, - } -} - -func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) { - s.mu.Lock() - defer s.mu.Unlock() - - delete(s.Routes, prefix) -} - -// MarshalJSON ensures that empty routes are marshaled as null -func (s *ShutdownState) MarshalJSON() ([]byte, error) { s.mu.RLock() defer s.mu.RUnlock() - if len(s.Routes) == 0 { - return json.Marshal(nil) + if s.Counter == nil { + return nil } - return json.Marshal(s.Routes) -} + sysops := NewSysOps(nil, nil) + sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) + sysops.refCounter.LoadData(s.Counter) -func (s *ShutdownState) UnmarshalJSON(data []byte) error { - return json.Unmarshal(data, &s.Routes) + return sysops.refCounter.Flush() } diff --git a/client/internal/routemanager/systemops/systemops_generic.go 
b/client/internal/routemanager/systemops/systemops_generic.go index 2b8a14ea2d2..4ff34aa5162 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -57,14 +57,14 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, refcounter.ErrIgnore } - r.updateState(stateManager, prefix, nexthop) + r.updateState(stateManager) return nexthop, err }, func(prefix netip.Prefix, nexthop Nexthop) error { // remove from state even if we have trouble removing it from the route table // it could be already gone - r.removeFromState(stateManager, prefix) + r.updateState(stateManager) return r.removeFromRouteTable(prefix, nexthop) }, @@ -75,21 +75,13 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return r.setupHooks(initAddresses) } -func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) { +func (r *SysOps) updateState(stateManager *statemanager.Manager) { state := getState(stateManager) - state.UpdateRoute(prefix, nexthop) - if err := stateManager.UpdateState(state); err != nil { - log.Errorf("failed to update state: %v", err) - } -} - -func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) { - state := getState(stateManager) - state.RemoveRoute(prefix) + state.Counter = r.refCounter if err := stateManager.UpdateState(state); err != nil { - log.Errorf("Failed to update state: %v", err) + log.Errorf("failed to update state: %v", err) } } @@ -107,7 +99,7 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { } if err := stateManager.DeleteState(&ShutdownState{}); err != nil { - log.Errorf("failed to delete state: %v", err) + return fmt.Errorf("delete state: %w", err) } return nil @@ -546,7 +538,7 @@ func getState(stateManager *statemanager.Manager) *ShutdownState { if state := stateManager.GetState(shutdownState); state != nil { shutdownState = state.(*ShutdownState) } else { - shutdownState = NewShutdownState() + shutdownState = &ShutdownState{} } return shutdownState diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go index 64c5316d871..96d6a9f12d3 100644 --- a/client/internal/statemanager/path.go +++ b/client/internal/statemanager/path.go @@ -5,7 +5,7 @@ import ( "path/filepath" "runtime" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" ) // GetDefaultStatePath returns the path to the state file based on the operating system @@ -27,7 +27,7 @@ func GetDefaultStatePath() string { dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0755); err != nil { - logrus.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err) + log.Errorf("Error creating directory %s: %v. 
Continuing without state support.", dir, err) return "" } diff --git a/client/server/server.go b/client/server/server.go index 342f61b883f..a0332208194 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -11,7 +11,6 @@ import ( "time" "github.com/cenkalti/backoff/v4" - "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/durationpb" @@ -21,11 +20,7 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" - nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/auth" - "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/routemanager/systemops" - "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/internal" @@ -848,31 +843,3 @@ func sendTerminalNotification() error { return wallCmd.Wait() } - -// restoreResidulaConfig check if the client was not shut down in a clean way and restores residual if required. -// Otherwise, we might not be able to connect to the management server to retrieve new config. -func restoreResidualState(ctx context.Context) error { - path := statemanager.GetDefaultStatePath() - if path == "" { - return nil - } - - mgr := statemanager.New(path) - - var merr *multierror.Error - - // register the states we are interested in restoring - // this will also allow each subsystem to record its own state - mgr.RegisterState(&dns.ShutdownState{}) - mgr.RegisterState(&systemops.ShutdownState{}) - - if err := mgr.PerformCleanup(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) - } - - if err := mgr.PersistState(ctx); err != nil { - merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) - } - - return nberrors.FormatErrorOrNil(merr) -} diff --git a/client/server/state.go b/client/server/state.go new file mode 100644 index 00000000000..509782e86c7 --- /dev/null +++ b/client/server/state.go @@ -0,0 +1,37 @@ +package server + +import ( + "context" + "fmt" + + "github.com/hashicorp/go-multierror" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required. +// Otherwise, we might not be able to connect to the management server to retrieve new config. +func restoreResidualState(ctx context.Context) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + mgr := statemanager.New(path) + + // register the states we are interested in restoring + registerStates(mgr) + + var merr *multierror.Error + if err := mgr.PerformCleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) + } + + // persist state regardless of cleanup outcome. 
It could've succeeded partially + if err := mgr.PersistState(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/client/server/state_generic.go b/client/server/state_generic.go new file mode 100644 index 00000000000..e6c7bdd44d7 --- /dev/null +++ b/client/server/state_generic.go @@ -0,0 +1,14 @@ +//go:build !linux || android + +package server + +import ( + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func registerStates(mgr *statemanager.Manager) { + mgr.RegisterState(&dns.ShutdownState{}) + mgr.RegisterState(&systemops.ShutdownState{}) +} diff --git a/client/server/state_linux.go b/client/server/state_linux.go new file mode 100644 index 00000000000..08762890719 --- /dev/null +++ b/client/server/state_linux.go @@ -0,0 +1,18 @@ +//go:build !android + +package server + +import ( + "github.com/netbirdio/netbird/client/firewall/iptables" + "github.com/netbirdio/netbird/client/firewall/nftables" + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func registerStates(mgr *statemanager.Manager) { + mgr.RegisterState(&dns.ShutdownState{}) + mgr.RegisterState(&systemops.ShutdownState{}) + mgr.RegisterState(&nftables.ShutdownState{}) + mgr.RegisterState(&iptables.ShutdownState{}) +} From 0fd874fa455ecabb9135a1810a496f248156ed3d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 28 Oct 2024 10:02:27 +0100 Subject: [PATCH 61/81] [client] Make native firewall init fail firewall creation (#2784) --- client/firewall/create_linux.go | 62 +++++++++++++++------------------ 1 file changed, 28 insertions(+), 34 deletions(-) diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index c853548f841..076d08ec27b 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -3,6 +3,7 @@ package firewall import ( + "errors" "fmt" "os" @@ -37,62 +38,55 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - fm, errFw := createNativeFirewall(iface) + fm, err := createNativeFirewall(iface, stateManager) - if fm != nil { - if err := fm.Init(stateManager); err != nil { - log.Errorf("failed to init nftables manager: %s", err) - } + if !iface.IsUserspaceBind() { + return fm, err + } + + if err != nil { + log.Warnf("failed to create native firewall: %v. 
Proceeding with userspace", err) } + return createUserspaceFirewall(iface, fm) +} - if iface.IsUserspaceBind() { - return createUserspaceFirewall(iface, fm, errFw) +func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { + fm, err := createFW(iface) + if err != nil { + return nil, fmt.Errorf("create firewall: %s", err) } - return fm, errFw + if err = fm.Init(stateManager); err != nil { + return nil, fmt.Errorf("init firewall: %s", err) + } + + return fm, nil } -func createNativeFirewall(iface IFaceMapper) (firewall.Manager, error) { +func createFW(iface IFaceMapper) (firewall.Manager, error) { switch check() { case IPTABLES: - return createIptablesFirewall(iface) + log.Info("creating an iptables firewall manager") + return nbiptables.Create(iface) case NFTABLES: - return createNftablesFirewall(iface) + log.Info("creating an nftables firewall manager") + return nbnftables.Create(iface) default: log.Info("no firewall manager found, trying to use userspace packet filtering firewall") - return nil, fmt.Errorf("no firewall manager found") + return nil, errors.New("no firewall manager found") } } -func createIptablesFirewall(iface IFaceMapper) (firewall.Manager, error) { - log.Info("creating an iptables firewall manager") - fm, err := nbiptables.Create(iface) - if err != nil { - log.Errorf("failed to create iptables manager: %s", err) - } - return fm, err -} - -func createNftablesFirewall(iface IFaceMapper) (firewall.Manager, error) { - log.Info("creating an nftables firewall manager") - fm, err := nbnftables.Create(iface) - if err != nil { - log.Errorf("failed to create nftables manager: %s", err) - } - return fm, err -} - -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, errFw error) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) { var errUsp error - if errFw == nil { + if fm != nil { fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) } else { fm, errUsp = uspfilter.Create(iface) } if errUsp != nil { - log.Debugf("failed to create userspace filtering firewall: %s", errUsp) - return nil, errUsp + return nil, fmt.Errorf("create userspace firewall: %s", errUsp) } if err := fm.AllowNetbird(); err != nil { From b9f205b2ce7943a01fe59c63eb6584fb9bdefd04 Mon Sep 17 00:00:00 2001 From: Stefano Date: Mon, 28 Oct 2024 10:08:17 +0100 Subject: [PATCH 62/81] [misc] Update Zitadel from v2.54.10 to v2.64.1 --- infrastructure_files/getting-started-with-zitadel.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 16b2364fb56..0b2b651429e 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -873,7 +873,7 @@ services: zitadel: restart: 'always' networks: [netbird] - image: 'ghcr.io/zitadel/zitadel:v2.54.10' + image: 'ghcr.io/zitadel/zitadel:v2.64.1' command: 'start-from-init --masterkeyFromEnv --tlsMode $ZITADEL_TLS_MODE' env_file: - ./zitadel.env From 46e37fa04c66117be4c0a561d30b5384ccf5ab66 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:28:44 +0100 Subject: [PATCH 63/81] [client] Ignore route rules with no sources instead of erroring out (#2786) --- client/internal/acl/manager.go | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git 
a/client/internal/acl/manager.go b/client/internal/acl/manager.go index ce2a12af16f..5bb0905d2a7 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -3,6 +3,7 @@ package acl import ( "crypto/md5" "encoding/hex" + "errors" "fmt" "net" "net/netip" @@ -10,14 +11,18 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/ssh" mgmProto "github.com/netbirdio/netbird/management/proto" ) +var ErrSourceRangesEmpty = errors.New("sources range is empty") + // Manager is a ACL rules manager type Manager interface { ApplyFiltering(networkMap *mgmProto.NetworkMap) @@ -167,31 +172,40 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { } func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { - var newRouteRules = make(map[id.RuleID]struct{}) + newRouteRules := make(map[id.RuleID]struct{}, len(rules)) + var merr *multierror.Error + + // Apply new rules - firewall manager will return existing rule ID if already present for _, rule := range rules { id, err := d.applyRouteACL(rule) if err != nil { - return fmt.Errorf("apply route ACL: %w", err) + if errors.Is(err, ErrSourceRangesEmpty) { + log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err) + } else { + merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err)) + } + continue } newRouteRules[id] = struct{}{} } + // Clean up old firewall rules for id := range d.routeRules { - if _, ok := newRouteRules[id]; !ok { + if _, exists := newRouteRules[id]; !exists { if err := d.firewall.DeleteRouteRule(id); err != nil { - log.Errorf("failed to delete route firewall rule: %v", err) - continue + merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err)) } - delete(d.routeRules, id) + // implicitly deleted from the map } } + d.routeRules = newRouteRules - return nil + return nberrors.FormatErrorOrNil(merr) } func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { if len(rule.SourceRanges) == 0 { - return "", fmt.Errorf("source ranges is empty") + return "", ErrSourceRangesEmpty } var sources []netip.Prefix From 940f8b454754c8a0e8184860cfc20f1bbf332d22 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 28 Oct 2024 12:29:29 +0100 Subject: [PATCH 64/81] [client] Remove legacy forwarding rules in userspace mode (#2782) --- client/firewall/iptables/router_linux.go | 2 ++ client/firewall/nftables/manager_linux.go | 18 +----------------- client/firewall/nftables/router_linux.go | 3 +++ client/firewall/uspfilter/uspfilter.go | 7 +++++-- 4 files changed, 11 insertions(+), 19 deletions(-) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 90811ae1182..9b75640b4b5 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -296,6 +296,8 @@ func (r *router) RemoveAllLegacyRouteRules() error { } if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) + } else { + delete(r.rules, k) } } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 
a4650f3b626..ea8912f27f5 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -230,23 +230,7 @@ func (m *Manager) AllowNetbird() error { // SetLegacyManagement sets the route manager to use legacy management func (m *Manager) SetLegacyManagement(isLegacy bool) error { - oldLegacy := m.router.legacyManagement - - if oldLegacy != isLegacy { - m.router.legacyManagement = isLegacy - log.Debugf("Set legacy management to %v", isLegacy) - } - - // client reconnected to a newer mgmt, we need to cleanup the legacy rules - if !isLegacy && oldLegacy { - if err := m.router.RemoveAllLegacyRouteRules(); err != nil { - return fmt.Errorf("remove legacy routing rules: %v", err) - } - - log.Debugf("Legacy routing rules removed") - } - - return nil + return firewall.SetLegacyManagement(m.router, isLegacy) } // Reset firewall to the default state diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 9b28e4eb213..0e7ea71b774 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -551,7 +551,10 @@ func (r *router) RemoveAllLegacyRouteRules() error { } if err := r.conn.DelRule(rule); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) + } else { + delete(r.rules, k) } + } return nberrors.FormatErrorOrNil(merr) } diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 3829a9baffe..af5dc673393 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -237,8 +237,11 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { } // SetLegacyManagement doesn't need to be implemented for this manager -func (m *Manager) SetLegacyManagement(_ bool) error { - return nil +func (m *Manager) SetLegacyManagement(isLegacy bool) error { + if m.nativeFirewall == nil { + return errRouteNotSupported + } + return m.nativeFirewall.SetLegacyManagement(isLegacy) } // Flush doesn't need to be implemented for this manager From 1e44c5b5747409c2e34fb3e4ae65bf26697adc3b Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 28 Oct 2024 16:55:00 +0100 Subject: [PATCH 65/81] [client] allow relay leader on iOS (#2795) --- client/internal/peer/ice/agent.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index b2a9669367e..dc4750f243a 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -1,13 +1,14 @@ package ice import ( - "github.com/netbirdio/netbird/client/internal/stdnet" + "time" + "github.com/pion/ice/v3" "github.com/pion/randutil" "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" - "runtime" - "time" + + "github.com/netbirdio/netbird/client/internal/stdnet" ) const ( @@ -77,10 +78,7 @@ func CandidateTypes() []ice.CandidateType { if hasICEForceRelayConn() { return []ice.CandidateType{ice.CandidateTypeRelay} } - // TODO: remove this once we have refactored userspace proxy into the bind package - if runtime.GOOS == "ios" { - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} - } + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} } From 10480eb52f305cbe235ba7c63349f7e89db42bbc Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> 
Date: Mon, 28 Oct 2024 17:52:23 +0100 Subject: [PATCH 66/81] [management] Setup key improvements (#2775) --- management/server/account.go | 1 + management/server/account_test.go | 9 - management/server/activity/codes.go | 3 + management/server/http/api/openapi.yml | 31 ++- management/server/http/api/types.gen.go | 2 +- management/server/http/handler.go | 1 + management/server/http/peers_handler_test.go | 4 +- management/server/http/setupkeys_handler.go | 43 +++- .../server/http/setupkeys_handler_test.go | 25 ++- management/server/metrics/selfhosted.go | 2 +- management/server/migration/migration.go | 90 +++++++++ management/server/migration/migration_test.go | 69 +++++++ management/server/mock_server/account_mock.go | 8 + management/server/peer.go | 9 +- management/server/peer/peer.go | 35 ++-- management/server/peer_test.go | 20 +- management/server/setupkey.go | 101 ++++++---- management/server/setupkey_test.go | 64 ++++-- management/server/sql_store.go | 28 ++- management/server/sql_store_test.go | 185 ++++++++++-------- management/server/status/error.go | 10 + management/server/store.go | 4 + 22 files changed, 549 insertions(+), 195 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index a8a244bdf1f..1810c6b41ec 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -153,6 +153,7 @@ type AccountManager interface { FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) + DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error } type DefaultAccountManager struct { diff --git a/management/server/account_test.go b/management/server/account_test.go index 3c3fcebc67f..1cd4ae449db 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1010,7 +1010,6 @@ func TestAccountManager_AddPeer(t *testing.T) { return } expectedPeerKey := key.PublicKey().String() - expectedSetupKey := setupKey.Key peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, @@ -1035,10 +1034,6 @@ func TestAccountManager_AddPeer(t *testing.T) { t.Errorf("expecting just added peer's IP %s to be in a network range %s", peer.IP.String(), account.Network.Net.String()) } - if peer.SetupKey != expectedSetupKey { - t.Errorf("expecting just added peer to have SetupKey = %s, got %s", expectedSetupKey, peer.SetupKey) - } - if account.Network.CurrentSerial() != 1 { t.Errorf("expecting Network Serial=%d to be incremented by 1 and be equal to %d when adding new peer to account", serial, account.Network.CurrentSerial()) } @@ -2367,7 +2362,6 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { LoginExpired: false, }, LoginExpirationEnabled: true, - SetupKey: "key", }, "peer-2": { Status: &nbpeer.PeerStatus{ @@ -2375,7 +2369,6 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { LoginExpired: false, }, LoginExpirationEnabled: true, - SetupKey: "key", }, }, expiration: time.Second, @@ -2529,7 +2522,6 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { LoginExpired: false, }, InactivityExpirationEnabled: true, - SetupKey: "key", }, "peer-2": { Status: &nbpeer.PeerStatus{ @@ -2537,7 +2529,6 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { LoginExpired: false, }, InactivityExpirationEnabled: true, - SetupKey: "key", }, }, 
expiration: time.Second, diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 188494241c6..603260dbcb2 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -146,6 +146,8 @@ const ( AccountPeerInactivityExpirationEnabled Activity = 65 AccountPeerInactivityExpirationDisabled Activity = 66 AccountPeerInactivityExpirationDurationUpdated Activity = 67 + + SetupKeyDeleted Activity = 68 ) var activityMap = map[Activity]Code{ @@ -219,6 +221,7 @@ var activityMap = map[Activity]Code{ AccountPeerInactivityExpirationEnabled: {"Account peer inactivity expiration enabled", "account.peer.inactivity.expiration.enable"}, AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, + SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"}, } // StringCode returns a string code of the activity diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 9d51482481a..9b4592ccf10 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -530,10 +530,9 @@ components: type: string example: reusable expires_in: - description: Expiration time in seconds + description: Expiration time in seconds, 0 will mean the key never expires type: integer - minimum: 86400 - maximum: 31536000 + minimum: 0 example: 86400 revoked: description: Setup key revocation status @@ -2018,6 +2017,32 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Setup Key + description: Delete a Setup Key + tags: [ Setup Keys ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: keyId + required: true + schema: + type: string + description: The unique identifier of a setup key + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/groups: get: summary: List all Groups diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index e2870d5d8ef..c1ef1ba2122 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -1101,7 +1101,7 @@ type SetupKeyRequest struct { // Ephemeral Indicate that the peer will be ephemeral or not Ephemeral *bool `json:"ephemeral,omitempty"` - // ExpiresIn Expiration time in seconds + // ExpiresIn Expiration time in seconds, 0 will mean the key never expires ExpiresIn int `json:"expires_in"` // Name Setup Key name diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 3f8a8554d07..c3928bff681 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -141,6 +141,7 @@ func (apiHandler *apiHandler) addSetupKeysEndpoint() { apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS") apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS") apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS") + 
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS") } func (apiHandler *apiHandler) addPoliciesEndpoint() { diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index f933eee1497..dd49c03b848 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -13,12 +13,13 @@ import ( "time" "github.com/gorilla/mux" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "golang.org/x/exp/maps" "github.com/stretchr/testify/assert" @@ -168,7 +169,6 @@ func TestGetPeers(t *testing.T) { peer := &nbpeer.Peer{ ID: testPeerID, Key: "key", - SetupKey: "setupkey", IP: net.ParseIP("100.64.0.1"), Status: &nbpeer.PeerStatus{Connected: true}, Name: "PeerName", diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 8514f0b556b..31859f59bf0 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -61,10 +61,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request expiresIn := time.Duration(req.ExpiresIn) * time.Second - day := time.Hour * 24 - year := day * 365 - if expiresIn < day || expiresIn > year { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn should be between 1 day and 365 days"), w) + if expiresIn < 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "expiresIn can not be in the past"), w) return } @@ -76,6 +74,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request if req.Ephemeral != nil { ephemeral = *req.Ephemeral } + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { @@ -83,7 +82,11 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request return } - writeSuccess(r.Context(), w, setupKey) + apiSetupKeys := toResponseBody(setupKey) + // for the creation we need to send the plain key + apiSetupKeys.Key = setupKey.Key + + util.WriteJSONObject(r.Context(), w, apiSetupKeys) } // GetSetupKey is a GET request to get a SetupKey by ID @@ -98,7 +101,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) + util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w) return } @@ -123,7 +126,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid key ID"), w) + util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w) return } @@ -181,6 +184,30 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques util.WriteJSONObject(r.Context(), w, apiSetupKeys) } +func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := 
h.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + keyID := vars["keyId"] + if len(keyID) == 0 { + util.WriteError(r.Context(), status.NewInvalidKeyIDError(), w) + return + } + + err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, emptyObject{}) +} + func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) @@ -206,7 +233,7 @@ func toResponseBody(key *server.SetupKey) *api.SetupKey { return &api.SetupKey{ Id: key.Id, - Key: key.Key, + Key: key.KeySecret, Name: key.Name, Expires: key.ExpiresAt, Type: string(key.Type), diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index 2d15287af25..09256d0ea5e 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -67,6 +67,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) { return []*server.SetupKey{defaultKey}, nil }, + + DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error { + if keyID == defaultKey.Id { + return nil + } + return status.Errorf(status.NotFound, "key %s not found", keyID) + }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { @@ -81,18 +88,21 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } func TestSetupKeysHandlers(t *testing.T) { - defaultSetupKey := server.GenerateDefaultSetupKey() + defaultSetupKey, _ := server.GenerateDefaultSetupKey() defaultSetupKey.Id = existingSetupKeyID adminUser := server.NewAdminUser("test_user") - newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, + newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, server.SetupKeyUnlimitedUsage, true) + newSetupKey.Key = plainKey updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey.AutoGroups = []string{"group-1"} updatedDefaultSetupKey.Name = updatedSetupKeyName updatedDefaultSetupKey.Revoked = true + expectedNewKey := toResponseBody(newSetupKey) + expectedNewKey.Key = plainKey tt := []struct { name string requestType string @@ -134,7 +144,7 @@ func TestSetupKeysHandlers(t *testing.T) { []byte(fmt.Sprintf("{\"name\":\"%s\",\"type\":\"%s\",\"expires_in\":86400, \"ephemeral\":true}", newSetupKey.Name, newSetupKey.Type))), expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKey: toResponseBody(newSetupKey), + expectedSetupKey: expectedNewKey, }, { name: "Update Setup Key", @@ -150,6 +160,14 @@ func TestSetupKeysHandlers(t *testing.T) { expectedBody: true, expectedSetupKey: toResponseBody(updatedDefaultSetupKey), }, + { + name: "Delete Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/" + defaultSetupKey.Id, + requestBody: bytes.NewBuffer([]byte("")), + expectedStatus: http.StatusOK, + expectedBody: false, + }, } handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser) @@ -164,6 +182,7 @@ func 
TestSetupKeysHandlers(t *testing.T) { router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS") router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index bdf744d211e..843fa575e83 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -267,7 +267,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties { peersSSHEnabled++ } - if peer.SetupKey == "" { + if peer.UserID != "" { userPeers++ } diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 4c8baea5e87..6f12d94b401 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -2,13 +2,16 @@ package migration import ( "context" + "crypto/sha256" "database/sql" + b64 "encoding/base64" "encoding/gob" "encoding/json" "errors" "fmt" "net" "strings" + "unicode/utf8" log "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -205,3 +208,90 @@ func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fi return nil } + +func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) error { + oldColumnName := "key" + newColumnName := "key_secret" + + var model T + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model) + return nil + } + + stmt := &gorm.Statement{DB: db} + err := stmt.Parse(&model) + if err != nil { + return fmt.Errorf("parse model: %w", err) + } + tableName := stmt.Schema.Table + + if err := db.Transaction(func(tx *gorm.DB) error { + if !tx.Migrator().HasColumn(&model, newColumnName) { + log.WithContext(ctx).Infof("Column %s does not exist in table %s, adding it", newColumnName, tableName) + if err := tx.Migrator().AddColumn(&model, newColumnName); err != nil { + return fmt.Errorf("add column %s: %w", newColumnName, err) + } + } + + var rows []map[string]any + if err := tx.Table(tableName). + Select("id", oldColumnName, newColumnName). + Where(newColumnName + " IS NULL OR " + newColumnName + " = ''"). + Where("SUBSTR(" + oldColumnName + ", 9, 1) = '-'"). 
+ Find(&rows).Error; err != nil { + return fmt.Errorf("find rows with empty secret key and matching pattern: %w", err) + } + + if len(rows) == 0 { + log.WithContext(ctx).Infof("No plain setup keys found in table %s, no migration needed", tableName) + return nil + } + + for _, row := range rows { + var plainKey string + if columnValue := row[oldColumnName]; columnValue != nil { + value, ok := columnValue.(string) + if !ok { + return fmt.Errorf("type assertion failed") + } + plainKey = value + } + + secretKey := hiddenKey(plainKey, 4) + + hashedKey := sha256.Sum256([]byte(plainKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + + if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(newColumnName, secretKey).Error; err != nil { + return fmt.Errorf("update row with secret key: %w", err) + } + + if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(oldColumnName, encodedHashedKey).Error; err != nil { + return fmt.Errorf("update row with hashed key: %w", err) + } + } + + if err := tx.Exec(fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", "peers", "setup_key")).Error; err != nil { + log.WithContext(ctx).Errorf("Failed to drop column %s: %v", "setup_key", err) + } + + return nil + }); err != nil { + return err + } + + log.Printf("Migration of plain setup key to hashed setup key completed") + return nil +} + +// hiddenKey returns the Key value hidden with "*" and a 5 character prefix. +// E.g., "831F6*******************************" +func hiddenKey(key string, length int) string { + prefix := key[0:5] + if length > utf8.RuneCountInString(key) { + length = utf8.RuneCountInString(key) - len(prefix) + } + return prefix + strings.Repeat("*", length) +} diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 5a192664169..51358c7ad67 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -160,3 +160,72 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { db.Model(&nbpeer.Peer{}).Select("location_connection_ip").First(&jsonStr) assert.JSONEq(t, `"10.0.0.1"`, jsonStr, "Data should be unchanged") } + +func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&server.SetupKey{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + err = db.Save(&server.SetupKey{ + Id: "1", + Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", + }).Error + require.NoError(t, err, "Failed to insert setup key") + + err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + require.NoError(t, err, "Migration should not fail to migrate setup key") + + var key server.SetupKey + err = db.Model(&server.SetupKey{}).First(&key).Error + assert.NoError(t, err, "Failed to fetch setup key") + + assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") + assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") +} + +func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&server.SetupKey{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + err = db.Save(&server.SetupKey{ + Id: "1", + Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + KeySecret: "EEFDA****", + }).Error + require.NoError(t, err, "Failed to insert setup key") + + err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + 
require.NoError(t, err, "Migration should not fail to migrate setup key") + + var key server.SetupKey + err = db.Model(&server.SetupKey{}).First(&key).Error + assert.NoError(t, err, "Failed to fetch setup key") + + assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") + assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") +} + +func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) { + db := setupDatabase(t) + + err := db.AutoMigrate(&server.SetupKey{}) + require.NoError(t, err, "Failed to auto-migrate tables") + + err = db.Save(&server.SetupKey{ + Id: "1", + Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", + }).Error + require.NoError(t, err, "Failed to insert setup key") + + err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + require.NoError(t, err, "Migration should not fail to migrate setup key") + + var key server.SetupKey + err = db.Model(&server.SetupKey{}).First(&key).Error + assert.NoError(t, err, "Failed to fetch setup key") + + assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 681bf533ae4..d7139bb2a5f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -109,6 +109,14 @@ type MockAccountManager struct { GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) + DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error +} + +func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { + if am.DeleteSetupKeyFunc != nil { + return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID) + } + return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { diff --git a/management/server/peer.go b/management/server/peer.go index 80d43497a70..96ede151158 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" "fmt" "net" "slices" @@ -396,6 +398,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } upperKey := strings.ToUpper(setupKey) + hashedKey := sha256.Sum256([]byte(upperKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) var accountID string var err error addedByUser := false @@ -403,7 +407,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s addedByUser = true accountID, err = am.Store.GetAccountIDByUserID(userID) } else { - accountID, err = am.Store.GetAccountIDBySetupKey(ctx, setupKey) + accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) } if err != nil { return nil, nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") @@ -448,7 +452,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s opEvent.Activity = activity.PeerAddedByUser } 
else { // Validate the setup key - sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey) + sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, encodedHashedKey) if err != nil { return fmt.Errorf("failed to get setup key: %w", err) } @@ -489,7 +493,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s ID: xid.New().String(), AccountID: accountID, Key: peer.Key, - SetupKey: upperKey, IP: freeIP, Meta: peer.Meta, Name: peer.Meta.Hostname, diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index ef96bce7dd8..1ff67da1231 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -16,8 +16,6 @@ type Peer struct { AccountID string `json:"-" gorm:"index"` // WireGuard public key Key string `gorm:"index"` - // A setup key this peer was registered with - SetupKey string `diff:"-"` // IP address of the Peer IP net.IP `gorm:"serializer:json"` // Meta is a Peer system meta data @@ -172,23 +170,22 @@ func (p *Peer) Copy() *Peer { peerStatus = p.Status.Copy() } return &Peer{ - ID: p.ID, - AccountID: p.AccountID, - Key: p.Key, - SetupKey: p.SetupKey, - IP: p.IP, - Meta: p.Meta, - Name: p.Name, - DNSLabel: p.DNSLabel, - Status: peerStatus, - UserID: p.UserID, - SSHKey: p.SSHKey, - SSHEnabled: p.SSHEnabled, - LoginExpirationEnabled: p.LoginExpirationEnabled, - LastLogin: p.LastLogin, - CreatedAt: p.CreatedAt, - Ephemeral: p.Ephemeral, - Location: p.Location, + ID: p.ID, + AccountID: p.AccountID, + Key: p.Key, + IP: p.IP, + Meta: p.Meta, + Name: p.Name, + DNSLabel: p.DNSLabel, + Status: peerStatus, + UserID: p.UserID, + SSHKey: p.SSHKey, + SSHEnabled: p.SSHEnabled, + LoginExpirationEnabled: p.LoginExpirationEnabled, + LastLogin: p.LastLogin, + CreatedAt: p.CreatedAt, + Ephemeral: p.Ephemeral, + Location: p.Location, InactivityExpirationEnabled: p.InactivityExpirationEnabled, } } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 7b2180bf019..5127f77fbe6 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" "fmt" "io" "net" @@ -1090,7 +1092,6 @@ func Test_RegisterPeerByUser(t *testing.T) { ID: xid.New().String(), AccountID: existingAccountID, Key: "newPeerKey", - SetupKey: "", IP: net.IP{123, 123, 123, 123}, Meta: nbpeer.PeerSystemMeta{ Hostname: "newPeer", @@ -1155,7 +1156,6 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ID: xid.New().String(), AccountID: existingAccountID, Key: "newPeerKey", - SetupKey: "existingSetupKey", UserID: "", IP: net.IP{123, 123, 123, 123}, Meta: nbpeer.PeerSystemMeta{ @@ -1175,7 +1175,6 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) - assert.Equal(t, peer.SetupKey, existingSetupKeyID) account, err := store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) @@ -1187,8 +1186,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed) - assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes) + + hashedKey := sha256.Sum256([]byte(existingSetupKeyID)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + 
assert.NotEqual(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed) + assert.Equal(t, 1, account.SetupKeys[encodedHashedKey].UsedTimes) } @@ -1221,7 +1223,6 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { ID: xid.New().String(), AccountID: existingAccountID, Key: "newPeerKey", - SetupKey: "existingSetupKey", UserID: "", IP: net.IP{123, 123, 123, 123}, Meta: nbpeer.PeerSystemMeta{ @@ -1250,8 +1251,11 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed.UTC()) - assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) + + hashedKey := sha256.Sum256([]byte(faultyKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + assert.Equal(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed.UTC()) + assert.Equal(t, 0, account.SetupKeys[encodedHashedKey].UsedTimes) } func TestPeerAccountPeersUpdate(t *testing.T) { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index e84f8fcd687..43b6e02c936 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -2,6 +2,9 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" + "fmt" "hash/fnv" "strconv" "strings" @@ -73,6 +76,7 @@ type SetupKey struct { // AccountID is a reference to Account that this object belongs AccountID string `json:"-" gorm:"index"` Key string + KeySecret string Name string Type SetupKeyType CreatedAt time.Time @@ -104,6 +108,7 @@ func (key *SetupKey) Copy() *SetupKey { Id: key.Id, AccountID: key.AccountID, Key: key.Key, + KeySecret: key.KeySecret, Name: key.Name, Type: key.Type, CreatedAt: key.CreatedAt, @@ -120,19 +125,17 @@ func (key *SetupKey) Copy() *SetupKey { // EventMeta returns activity event meta related to the setup key func (key *SetupKey) EventMeta() map[string]any { - return map[string]any{"name": key.Name, "type": key.Type, "key": key.HiddenCopy(1).Key} + return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} } -// HiddenCopy returns a copy of the key with a Key value hidden with "*" and a 5 character prefix. +// hiddenKey returns the Key value hidden with "*" and a 5 character prefix. 
// E.g., "831F6*******************************" -func (key *SetupKey) HiddenCopy(length int) *SetupKey { - k := key.Copy() - prefix := k.Key[0:5] - if length > utf8.RuneCountInString(key.Key) { - length = utf8.RuneCountInString(key.Key) - len(prefix) - } - k.Key = prefix + strings.Repeat("*", length) - return k +func hiddenKey(key string, length int) string { + prefix := key[0:5] + if length > utf8.RuneCountInString(key) { + length = utf8.RuneCountInString(key) - len(prefix) + } + return prefix + strings.Repeat("*", length) } // IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now @@ -155,6 +158,9 @@ func (key *SetupKey) IsRevoked() bool { // IsExpired if key was expired func (key *SetupKey) IsExpired() bool { + if key.ExpiresAt.IsZero() { + return false + } return time.Now().After(key.ExpiresAt) } @@ -169,30 +175,40 @@ func (key *SetupKey) IsOverUsed() bool { // GenerateSetupKey generates a new setup key func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, - usageLimit int, ephemeral bool) *SetupKey { + usageLimit int, ephemeral bool) (*SetupKey, string) { key := strings.ToUpper(uuid.New().String()) limit := usageLimit if t == SetupKeyOneOff { limit = 1 } + + expiresAt := time.Time{} + if validFor != 0 { + expiresAt = time.Now().UTC().Add(validFor) + } + + hashedKey := sha256.Sum256([]byte(key)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + return &SetupKey{ Id: strconv.Itoa(int(Hash(key))), - Key: key, + Key: encodedHashedKey, + KeySecret: hiddenKey(key, 4), Name: name, Type: t, CreatedAt: time.Now().UTC(), - ExpiresAt: time.Now().UTC().Add(validFor), + ExpiresAt: expiresAt, UpdatedAt: time.Now().UTC(), Revoked: false, UsedTimes: 0, AutoGroups: autoGroups, UsageLimit: limit, Ephemeral: ephemeral, - } + }, key } // GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration -func GenerateDefaultSetupKey() *SetupKey { +func GenerateDefaultSetupKey() (*SetupKey, string) { return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, SetupKeyUnlimitedUsage, false) } @@ -213,11 +229,6 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - keyDuration := DefaultSetupKeyDuration - if expiresIn != 0 { - keyDuration = expiresIn - } - account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err @@ -227,7 +238,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, err } - setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral) + setupKey, plainKey := GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) account.SetupKeys[setupKey.Key] = setupKey err = am.Store.SaveAccount(ctx, account) if err != nil { @@ -246,6 +257,9 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s } } + // for the creation return the plain key to the caller + setupKey.Key = plainKey + return setupKey, nil } @@ -334,7 +348,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u } if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") + return nil, status.NewUnauthorizedToViewSetupKeysError() } setupKeys, err := 
am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) @@ -342,18 +356,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, err } - keys := make([]*SetupKey, 0, len(setupKeys)) - for _, key := range setupKeys { - var k *SetupKey - if !user.IsAdminOrServiceUser() { - k = key.HiddenCopy(999) - } else { - k = key.Copy() - } - keys = append(keys, k) - } - - return keys, nil + return setupKeys, nil } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. @@ -364,7 +367,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use } if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") + return nil, status.NewUnauthorizedToViewSetupKeysError() } setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) @@ -377,11 +380,33 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use setupKey.UpdatedAt = setupKey.CreatedAt } - if !user.IsAdminOrServiceUser() { - setupKey = setupKey.HiddenCopy(999) + return setupKey, nil +} + +// DeleteSetupKey removes the setup key from the account +func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return fmt.Errorf("failed to get user: %w", err) } - return setupKey, nil + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return status.NewUnauthorizedToViewSetupKeysError() + } + + deletedSetupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + if err != nil { + return fmt.Errorf("failed to get setup key: %w", err) + } + + err = am.Store.DeleteSetupKey(ctx, accountID, keyID) + if err != nil { + return fmt.Errorf("failed to delete setup key: %w", err) + } + + am.StoreEvent(ctx, userID, keyID, accountID, activity.SetupKeyDeleted, deletedSetupKey.EventMeta()) + + return nil } func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 651b5401047..2ed8aef95c6 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -2,8 +2,11 @@ package server import ( "context" + "crypto/sha256" + "encoding/base64" "fmt" "strconv" + "strings" "testing" "time" @@ -66,7 +69,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, - key.Id, time.Now().UTC(), autoGroups) + key.Id, time.Now().UTC(), autoGroups, true) // check the corresponding events that should have been generated ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked) @@ -183,7 +186,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes, tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), - tCase.expectedUpdatedAt, tCase.expectedGroups) + tCase.expectedUpdatedAt, tCase.expectedGroups, false) // check the corresponding events that should have been generated ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated) @@ -239,10 +242,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) { expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour) var expectedAutoGroups 
[]string - key := GenerateDefaultSetupKey() + key, plainKey := GenerateDefaultSetupKey() assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups) + expectedExpiresAt, strconv.Itoa(int(Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) } @@ -256,41 +259,41 @@ func TestGenerateSetupKey(t *testing.T) { expectedUpdatedAt := time.Now().UTC() var expectedAutoGroups []string - key := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, plain := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), expectedUpdatedAt, expectedAutoGroups) + expectedExpiresAt, strconv.Itoa(int(Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) } func TestSetupKey_IsValid(t *testing.T) { - validKey := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + validKey, _ := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) if !validKey.IsValid() { t.Errorf("expected key to be valid, got invalid %v", validKey) } // expired - expiredKey := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + expiredKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) if expiredKey.IsValid() { t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey) } // revoked - revokedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + revokedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) revokedKey.Revoked = true if revokedKey.IsValid() { t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey) } // overused - overUsedKey := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + overUsedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) overUsedKey.UsedTimes = 1 if overUsedKey.IsValid() { t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey) } // overused - reusableKey := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + reusableKey, _ := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) reusableKey.UsedTimes = 99 if !reusableKey.IsValid() { t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey) @@ -299,7 +302,7 @@ func TestSetupKey_IsValid(t *testing.T) { func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string, expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string, - expectedUpdatedAt time.Time, expectedAutoGroups []string) { + expectedUpdatedAt time.Time, expectedAutoGroups []string, expectHashedKey bool) { t.Helper() if key.Name != expectedName { t.Errorf("expected setup key to have Name %v, got %v", expectedName, key.Name) @@ -329,13 +332,23 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, 
expectedRevoke t.Errorf("expected setup key to have CreatedAt ~ %v, got %v", expectedCreatedAt, key.CreatedAt) } - _, err := uuid.Parse(key.Key) - if err != nil { - t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err) + if expectHashedKey { + if !isValidBase64SHA256(key.Key) { + t.Errorf("expected key to be hashed, got %v", key.Key) + } + } else { + _, err := uuid.Parse(key.Key) + if err != nil { + t.Errorf("expected key to be a valid UUID, got %v, %v", key.Key, err) + } } - if key.Id != strconv.Itoa(int(Hash(key.Key))) { - t.Errorf("expected key Id t= %v, got %v", expectedID, key.Id) + if !strings.HasSuffix(key.KeySecret, "****") { + t.Errorf("expected key secret to be secure, got %v", key.Key) + } + + if key.Id != expectedID { + t.Errorf("expected key Id %v, got %v", expectedID, key.Id) } if len(key.AutoGroups) != len(expectedAutoGroups) { @@ -344,13 +357,26 @@ func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke assert.ElementsMatch(t, key.AutoGroups, expectedAutoGroups, "expected key AutoGroups to be equal") } +func isValidBase64SHA256(encodedKey string) bool { + decoded, err := base64.StdEncoding.DecodeString(encodedKey) + if err != nil { + return false + } + + if len(decoded) != sha256.Size { + return false + } + + return true +} + func TestSetupKey_Copy(t *testing.T) { - key := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, _ := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) keyCopy := key.Copy() assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id, - key.UpdatedAt, key.AutoGroups) + key.UpdatedAt, key.AutoGroups, true) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 47395f51109..27238d28e8a 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -469,7 +469,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { var key SetupKey - result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey)) + result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -741,7 +741,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var accountID string - result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID) + result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -973,7 +973,7 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { var setupKey SetupKey result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). 
- First(&setupKey, keyQueryCondition, strings.ToUpper(key)) + First(&setupKey, keyQueryCondition, key) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "setup key not found") @@ -1232,6 +1232,10 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) } +func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error { + return deleteRecordByID[SetupKey](s.db.WithContext(ctx), LockingStrengthUpdate, keyID, accountID) +} + // getRecords retrieves records from the database based on the account ID. func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { var record []T @@ -1264,3 +1268,21 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } return &record, nil } + +// deleteRecordByID deletes a record by its ID and account ID from the database. +func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error { + var record T + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(record, accountAndIDQueryCondition, accountID, recordID) + if err := result.Error; err != nil { + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] + + return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err) + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "record not found") + } + + return nil +} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 000eb1b11b2..b371e231319 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/sha256" + b64 "encoding/base64" "fmt" "math/rand" "net" @@ -71,7 +73,7 @@ func runLargeTest(t *testing.T, store Store) { if err != nil { t.Fatal(err) } - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey const numPerAccount = 6000 for n := 0; n < numPerAccount; n++ { @@ -81,7 +83,6 @@ func runLargeTest(t *testing.T, store Store) { peer := &nbpeer.Peer{ ID: peerID, Key: peerID, - SetupKey: "", IP: netIP, Name: peerID, DNSLabel: peerID, @@ -133,7 +134,7 @@ func runLargeTest(t *testing.T, store Store) { } account.NameServerGroups[nameserver.ID] = nameserver - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey } @@ -215,30 +216,28 @@ func TestSqlite_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), 
"account_id2", "testuser2", "") - setupKey = GenerateDefaultSetupKey() + setupKey, _ = GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - SetupKey: "peerkeysetupkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey2", + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account2) @@ -297,15 +296,14 @@ func TestSqlite_DeleteAccount(t *testing.T) { }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } account.Users[testUserID] = user @@ -394,13 +392,12 @@ func TestSqlite_SavePeer(t *testing.T) { // save status of non-existing peer peer := &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } ctx := context.Background() err = store.SavePeer(ctx, account.Id, peer) @@ -453,13 +450,12 @@ func TestSqlite_SavePeerStatus(t *testing.T) { // save new status of existing peer account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) @@ -720,15 +716,14 @@ func newSqliteStore(t *testing.T) *SqlStore { func newAccount(store Store, id int) error { str := fmt.Sprintf("%s-%d", uuid.New().String(), id) account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["p"+str] = &nbpeer.Peer{ - Key: "peerkey" + str, - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey" + str, + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } return store.SaveAccount(context.Background(), account) @@ -760,30 +755,28 @@ 
func TestPostgresql_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey = GenerateDefaultSetupKey() + setupKey, _ = GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - SetupKey: "peerkeysetupkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey2", + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account2) @@ -842,15 +835,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey := GenerateDefaultSetupKey() + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } account.Users[testUserID] = user @@ -921,13 +913,12 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { // save new status of existing peer account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - SetupKey: "peerkeysetupkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) @@ -1118,12 +1109,17 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + plainKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + hashedKey := sha256.Sum256([]byte(plainKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) - assert.Equal(t, 
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key) + assert.Equal(t, encodedHashedKey, setupKey.Key) + assert.Equal(t, hiddenKey(plainKey, 4), setupKey.KeySecret) assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID) assert.Equal(t, "Default key", setupKey.Name) } @@ -1138,24 +1134,28 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + plainKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + hashedKey := sha256.Sum256([]byte(plainKey)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 0, setupKey.UsedTimes) err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) require.NoError(t, err) - setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 1, setupKey.UsedTimes) err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) require.NoError(t, err) - setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 2, setupKey.UsedTimes) } @@ -1264,3 +1264,32 @@ func TestSqlite_GetGroupByName(t *testing.T) { require.NoError(t, err) require.Equal(t, "All", group.Name) } + +func Test_DeleteSetupKeySuccessfully(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + + err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) + require.NoError(t, err) + + _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) + require.Error(t, err) +} + +func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + nonExistingKeyID := "non-existing-key-id" + + err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) + require.Error(t, err) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 29d185216d8..e9fc8c15ef9 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -114,3 +114,13 @@ func NewGetAccountFromStoreError(err error) error { func NewGetUserFromStoreError() error { return Errorf(Internal, "issue getting user from store") } + +// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key +func NewInvalidKeyIDError() error { + return 
Errorf(InvalidArgument, "invalid key ID") +} + +// NewUnauthorizedToViewSetupKeysError creates a new Error with Unauthorized type for an issue getting a setup key +func NewUnauthorizedToViewSetupKeysError() error { + return Errorf(Unauthorized, "only users with admin power can view setup keys") +} diff --git a/management/server/store.go b/management/server/store.go index 131fd8aaab6..087c9884763 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -124,6 +124,7 @@ type Store interface { // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error + DeleteSetupKey(ctx context.Context, accountID, keyID string) error } type StoreEngine string @@ -241,6 +242,9 @@ func getMigrations(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip") }, + func(db *gorm.DB) error { + return migration.MigrateSetupKeyToHashedSetupKey[SetupKey](ctx, db) + }, } } From 01f24907c595ca641035b28159aefe6cf09117d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marco=20Garc=C3=AAs?= Date: Tue, 29 Oct 2024 16:49:41 +0000 Subject: [PATCH 67/81] [client] Fix multiple peer name filtering in netbird status command (#2798) --- client/cmd/status.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/client/cmd/status.go b/client/cmd/status.go index ed3daa2b5fd..6db52a67795 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool { statusEval := false ipEval := false - nameEval := false + nameEval := true if statusFilter != "" { lowerStatusFilter := strings.ToLower(statusFilter) @@ -700,11 +700,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool { if len(prefixNamesFilter) > 0 { for prefixNameFilter := range prefixNamesFilterMap { - if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) { - nameEval = true + if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) { + nameEval = false break } } + } else { + nameEval = false } return statusEval || ipEval || nameEval From 39c99781cb9837de0c3f96cc20fbed9ed7023c28 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 29 Oct 2024 19:54:38 +0100 Subject: [PATCH 68/81] fix meta is equal slices (#2807) --- management/server/peer/peer.go | 13 +++++++ management/server/peer/peer_test.go | 54 +++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 1ff67da1231..82e0acf3ade 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -4,6 +4,7 @@ import ( "net" "net/netip" "slices" + "sort" "time" ) @@ -107,6 +108,12 @@ type PeerSystemMeta struct { //nolint:revive } func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { + sort.Slice(p.NetworkAddresses, func(i, j int) bool { + return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac + }) + sort.Slice(other.NetworkAddresses, func(i, j int) bool { + return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac + }) equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool { return addr.Mac == oAddr.Mac && addr.NetIP == 
oAddr.NetIP }) @@ -114,6 +121,12 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool { return false } + sort.Slice(p.Files, func(i, j int) bool { + return p.Files[i].Path < p.Files[j].Path + }) + sort.Slice(other.Files, func(i, j int) bool { + return other.Files[i].Path < other.Files[j].Path + }) equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool { return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning }) diff --git a/management/server/peer/peer_test.go b/management/server/peer/peer_test.go index 7b94f68c67f..3d3a2e31108 100644 --- a/management/server/peer/peer_test.go +++ b/management/server/peer/peer_test.go @@ -2,6 +2,7 @@ package peer import ( "fmt" + "net/netip" "testing" ) @@ -29,3 +30,56 @@ func BenchmarkFQDN(b *testing.B) { } }) } + +func TestIsEqual(t *testing.T) { + meta1 := PeerSystemMeta{ + NetworkAddresses: []NetworkAddress{{ + NetIP: netip.MustParsePrefix("192.168.1.2/24"), + Mac: "2", + }, + { + NetIP: netip.MustParsePrefix("192.168.1.0/24"), + Mac: "1", + }, + }, + Files: []File{ + { + Path: "/etc/hosts1", + Exist: true, + ProcessIsRunning: true, + }, + { + Path: "/etc/hosts2", + Exist: false, + ProcessIsRunning: false, + }, + }, + } + meta2 := PeerSystemMeta{ + NetworkAddresses: []NetworkAddress{ + { + NetIP: netip.MustParsePrefix("192.168.1.0/24"), + Mac: "1", + }, + { + NetIP: netip.MustParsePrefix("192.168.1.2/24"), + Mac: "2", + }, + }, + Files: []File{ + { + Path: "/etc/hosts2", + Exist: false, + ProcessIsRunning: false, + }, + { + Path: "/etc/hosts1", + Exist: true, + ProcessIsRunning: true, + }, + }, + } + if !meta1.isEqual(meta2) { + t.Error("meta1 should be equal to meta2") + } +} From a0cdb58303807aab0fafc5d64ed1e98015cf20e7 Mon Sep 17 00:00:00 2001 From: Jing Date: Tue, 29 Oct 2024 12:17:40 -0700 Subject: [PATCH 69/81] [client] Fix the broken dependency gvisor.dev/gvisor (#2789) The release was removed which is described at https://github.com/google/gvisor/issues/11085#issuecomment-2438974962. 
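For reference, a dependency bump like this one is normally produced with the standard Go module tooling rather than by editing go.mod and go.sum by hand. A minimal sketch of the equivalent commands (assuming a working Go toolchain; these commands are not part of the patch itself, only the resulting go.mod/go.sum hunks below are):

    # pull the replacement gvisor pseudo-version referenced in this patch,
    # which also requires github.com/google/btree v1.1.2 transitively
    go get gvisor.dev/gvisor@v0.0.0-20231020174304-b8a429915ff1
    # refresh go.sum checksum entries for the new versions
    go mod tidy

Running go mod tidy afterwards regenerates the checksum entries, which is what the go.sum changes below reflect.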
--- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index a6b83794dab..7223a446bb1 100644 --- a/go.mod +++ b/go.mod @@ -156,7 +156,7 @@ require ( github.com/go-text/typesetting v0.1.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/google/btree v1.0.1 // indirect + github.com/google/btree v1.1.2 // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.3 // indirect @@ -231,7 +231,7 @@ require ( gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect - gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 // indirect + gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 // indirect k8s.io/apimachinery v0.26.2 // indirect ) diff --git a/go.sum b/go.sum index 412542d5eb9..5cd703bc894 100644 --- a/go.sum +++ b/go.sum @@ -297,8 +297,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= -github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= @@ -1238,8 +1238,8 @@ gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4 gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= -gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= -gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= +gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs= +gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= From 729bcf2b01b0a50f5fcd326394c43df33c9ab2b2 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Wed, 30 Oct 2024 16:53:23 +0100 Subject: [PATCH 70/81] [management] add metrics to network map diff (#2811) --- .../server/telemetry/updatechannel_metrics.go | 12 ++++++ 
management/server/updatechannel.go | 15 +++++-- management/server/updatechannel_test.go | 42 +++++++++++-------- 3 files changed, 47 insertions(+), 22 deletions(-) diff --git a/management/server/telemetry/updatechannel_metrics.go b/management/server/telemetry/updatechannel_metrics.go index 2582006e517..fb33b663c62 100644 --- a/management/server/telemetry/updatechannel_metrics.go +++ b/management/server/telemetry/updatechannel_metrics.go @@ -18,6 +18,7 @@ type UpdateChannelMetrics struct { getAllConnectedPeersDurationMicro metric.Int64Histogram getAllConnectedPeers metric.Int64Histogram hasChannelDurationMicro metric.Int64Histogram + networkMapDiffDurationMicro metric.Int64Histogram ctx context.Context } @@ -63,6 +64,11 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh return nil, err } + networkMapDiffDurationMicro, err := meter.Int64Histogram("management.updatechannel.networkmap.diff.duration.micro") + if err != nil { + return nil, err + } + return &UpdateChannelMetrics{ createChannelDurationMicro: createChannelDurationMicro, closeChannelDurationMicro: closeChannelDurationMicro, @@ -72,6 +78,7 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh getAllConnectedPeersDurationMicro: getAllConnectedPeersDurationMicro, getAllConnectedPeers: getAllConnectedPeers, hasChannelDurationMicro: hasChannelDurationMicro, + networkMapDiffDurationMicro: networkMapDiffDurationMicro, ctx: ctx, }, nil } @@ -111,3 +118,8 @@ func (metrics *UpdateChannelMetrics) CountGetAllConnectedPeersDuration(duration func (metrics *UpdateChannelMetrics) CountHasChannelDuration(duration time.Duration) { metrics.hasChannelDurationMicro.Record(metrics.ctx, duration.Microseconds()) } + +// CountNetworkMapDiffDurationMicro counts the duration of the NetworkMapDiff method +func (metrics *UpdateChannelMetrics) CountNetworkMapDiffDurationMicro(duration time.Duration) { + metrics.networkMapDiffDurationMicro.Record(metrics.ctx, duration.Microseconds()) +} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 6fb96c97124..7c73002223b 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -7,11 +7,11 @@ import ( "sync" "time" - "github.com/netbirdio/netbird/management/server/differs" "github.com/r3labs/diff/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/differs" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -208,10 +208,10 @@ func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID p.channelsMux.RUnlock() if lastSentUpdate != nil { - updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update) + updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update, p.metrics) if err != nil { log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err) - return false + return true } if !updated { log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID) @@ -223,7 +223,9 @@ func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID } // isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent. 
-func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage) (isNew bool, err error) { +func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage, metric telemetry.AppMetrics) (isNew bool, err error) { + startTime := time.Now() + defer func() { if r := recover(); r != nil { log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack()) @@ -258,6 +260,11 @@ func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSen if err != nil { return false, fmt.Errorf("failed to diff network map: %v", err) } + + if metric != nil { + metric.UpdateChannelMetrics().CountNetworkMapDiffDurationMicro(time.Since(startTime)) + } + return len(changelog) > 0, nil } diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 52b715e9503..b8a0ce45f73 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -7,14 +7,16 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/telemetry" nbroute "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/util" - "github.com/stretchr/testify/assert" ) // var peersUpdater *PeersUpdateManager @@ -175,8 +177,12 @@ func TestHandlePeerMessageUpdate(t *testing.T) { } for _, tt := range tests { + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + t.Fatal(err) + } t.Run(tt.name, func(t *testing.T) { - p := NewPeersUpdateManager(nil) + p := NewPeersUpdateManager(metrics) ctx := context.Background() if tt.existingUpdate != nil { @@ -194,7 +200,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { newUpdateMessage1 := createMockUpdateMessage(t) newUpdateMessage2 := createMockUpdateMessage(t) - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.False(t, message) }) @@ -205,7 +211,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.False(t, message) }) @@ -217,7 +223,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32") newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.True(t, message) @@ -230,7 +236,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"} newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, 
newUpdateMessage2, nil) assert.NoError(t, err) assert.True(t, message) }) @@ -249,7 +255,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer) newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.True(t, message) }) @@ -259,14 +265,14 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { newUpdateMessage2 := createMockUpdateMessage(t) newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.False(t, message) newUpdateMessage3 := createMockUpdateMessage(t) newUpdateMessage3.Update.Checks = []*proto.Checks{} newUpdateMessage3.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3) + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3, nil) assert.NoError(t, err) assert.True(t, message) @@ -285,7 +291,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { } newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)} newUpdateMessage4.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4) + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4, nil) assert.NoError(t, err) assert.True(t, message) @@ -305,7 +311,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { } newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)} newUpdateMessage5.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5) + message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5, nil) assert.NoError(t, err) assert.True(t, message) }) @@ -321,7 +327,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { ) newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.True(t, message) }) @@ -333,7 +339,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10") newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.True(t, message) }) @@ -345,7 +351,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443" newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.True(t, message) }) @@ -364,7 +370,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { 
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule) newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.True(t, message) }) @@ -376,7 +382,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0) newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.True(t, message) }) @@ -388,7 +394,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4") newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.True(t, message) }) @@ -400,7 +406,7 @@ func TestIsNewPeerUpdateMessage(t *testing.T) { newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2" newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2) + message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) assert.NoError(t, err) assert.True(t, message) }) From 49a54624f8ed32efde962b355fe84a2b8fe83659 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Wed, 30 Oct 2024 17:18:27 +0100 Subject: [PATCH 71/81] Create funding.json (#2813) --- funding.json | 126 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 funding.json diff --git a/funding.json b/funding.json new file mode 100644 index 00000000000..6b509a9921b --- /dev/null +++ b/funding.json @@ -0,0 +1,126 @@ +{ + "version": "v1.0.0", + "entity": { + "type": "organisation", + "role": "owner", + "name": "NetBird GmbH", + "email": "hello@netbird.io", + "phone": "", + "description": "NetBird GmbH is a Berlin-based software company specializing in the development of open-source network security solutions. Network security is utterly complex and expensive, accessible only to companies with multi-million dollar IT budgets. In contrast, there are millions of companies left behind. Our mission is to create an advanced network and cybersecurity platform that is both easy-to-use and affordable for teams of all sizes and budgets. By leveraging the open-source strategy and technological advancements, NetBird aims to set the industry standard for connecting and securing IT infrastructure.", + "webpageUrl": { + "url": "https://github.com/netbirdio" + } + }, + "projects": [ + { + "guid": "netbird", + "name": "NetBird", + "description": "NetBird is a configuration-free peer-to-peer private network and a centralized access control system combined in a single open-source platform. 
It makes it easy to create secure WireGuard-based private networks for your organization or home.", + "webpageUrl": { + "url": "https://github.com/netbirdio/netbird" + }, + "repositoryUrl": { + "url": "https://github.com/netbirdio/netbird" + }, + "licenses": [ + "BSD-3" + ], + "tags": [ + "network-security", + "vpn", + "developer-tools", + "ztna", + "zero-trust", + "remote-access", + "wireguard", + "peer-to-peer", + "private-networking", + "software-defined-networking" + ] + } + ], + "funding": { + "channels": [ + { + "guid": "github-sponsors", + "type": "payment-provider", + "address": "https://github.com/sponsors/netbirdio", + "description": "" + }, + { + "guid": "bank-transfer", + "type": "bank", + "address": "", + "description": "Contact us at hello@netbird.io for bank transfer details." + } + ], + "plans": [ + { + "guid": "support-yearly", + "status": "active", + "name": "Support Open Source Development and Maintenance - Yearly", + "description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.", + "amount": 100000, + "currency": "USD", + "frequency": "yearly", + "channels": [ + "github-sponsors", + "bank-transfer" + ] + }, + { + "guid": "support-one-time-year", + "status": "active", + "name": "Support Open Source Development and Maintenance - One Year", + "description": "This will help us partially cover the yearly cost of maintaining the open-source NetBird project.", + "amount": 100000, + "currency": "USD", + "frequency": "one-time", + "channels": [ + "github-sponsors", + "bank-transfer" + ] + }, + { + "guid": "support-one-time-monthly", + "status": "active", + "name": "Support Open Source Development and Maintenance - Monthly", + "description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.", + "amount": 10000, + "currency": "USD", + "frequency": "monthly", + "channels": [ + "github-sponsors", + "bank-transfer" + ] + }, + { + "guid": "support-monthly", + "status": "active", + "name": "Support Open Source Development and Maintenance - One Month", + "description": "This will help us partially cover the monthly cost of maintaining the open-source NetBird project.", + "amount": 10000, + "currency": "USD", + "frequency": "monthly", + "channels": [ + "github-sponsors", + "bank-transfer" + ] + }, + { + "guid": "goodwill", + "status": "active", + "name": "Goodwill Plan", + "description": "Pay anything you wish to show your goodwill for the project.", + "amount": 0, + "currency": "USD", + "frequency": "monthly", + "channels": [ + "github-sponsors", + "bank-transfer" + ] + } + ], + "history": null + } +} From ec5095ba6b0c5f3c43a652bf5afeda07fcaffb55 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Wed, 30 Oct 2024 17:25:02 +0100 Subject: [PATCH 72/81] Create FUNDING.yml (#2814) --- .github/FUNDING.yml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .github/FUNDING.yml diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 00000000000..c3d322163e2 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +# These are supported funding model platforms + +github: [netbirdio] From 4c758c6e526a5ba8ee0088263d5883a2c817a190 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 31 Oct 2024 19:24:15 +0100 Subject: [PATCH 73/81] [management] remove network map diff calculations (#2820) --- go.mod | 3 - go.sum | 6 - management/server/differs/netip.go | 82 --- management/server/dns_test.go | 73 ++- 
management/server/group_test.go | 28 +- management/server/nameserver_test.go | 30 -- management/server/network.go | 4 +- management/server/peer/peer.go | 18 +- management/server/policy.go | 4 +- management/server/policy_test.go | 62 +-- management/server/posture_checks_test.go | 65 +-- management/server/route_test.go | 20 - .../server/telemetry/updatechannel_metrics.go | 12 - management/server/updatechannel.go | 110 +--- management/server/updatechannel_test.go | 482 ------------------ 15 files changed, 82 insertions(+), 917 deletions(-) delete mode 100644 management/server/differs/netip.go diff --git a/go.mod b/go.mod index 7223a446bb1..571b41abf19 100644 --- a/go.mod +++ b/go.mod @@ -71,7 +71,6 @@ require ( github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.19.1 - github.com/r3labs/diff/v3 v3.0.1 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 @@ -211,8 +210,6 @@ require ( github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect - github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect - github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/yuin/goldmark v1.7.1 // indirect github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect diff --git a/go.sum b/go.sum index 5cd703bc894..217d27e0adf 100644 --- a/go.sum +++ b/go.sum @@ -605,8 +605,6 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= -github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg= -github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= @@ -699,10 +697,6 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= -github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= -github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= -github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= -github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/management/server/differs/netip.go b/management/server/differs/netip.go deleted file mode 100644 index de4aa334c17..00000000000 --- a/management/server/differs/netip.go +++ /dev/null @@ 
-1,82 +0,0 @@ -package differs - -import ( - "fmt" - "net/netip" - "reflect" - - "github.com/r3labs/diff/v3" -) - -// NetIPAddr is a custom differ for netip.Addr -type NetIPAddr struct { - DiffFunc func(path []string, a, b reflect.Value, p interface{}) error -} - -func (differ NetIPAddr) Match(a, b reflect.Value) bool { - return diff.AreType(a, b, reflect.TypeOf(netip.Addr{})) -} - -func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { - if a.Kind() == reflect.Invalid { - cl.Add(diff.CREATE, path, nil, b.Interface()) - return nil - } - - if b.Kind() == reflect.Invalid { - cl.Add(diff.DELETE, path, a.Interface(), nil) - return nil - } - - fromAddr, ok1 := a.Interface().(netip.Addr) - toAddr, ok2 := b.Interface().(netip.Addr) - if !ok1 || !ok2 { - return fmt.Errorf("invalid type for netip.Addr") - } - - if fromAddr.String() != toAddr.String() { - cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String()) - } - - return nil -} - -func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { - differ.DiffFunc = dfunc //nolint -} - -// NetIPPrefix is a custom differ for netip.Prefix -type NetIPPrefix struct { - DiffFunc func(path []string, a, b reflect.Value, p interface{}) error -} - -func (differ NetIPPrefix) Match(a, b reflect.Value) bool { - return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{})) -} - -func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error { - if a.Kind() == reflect.Invalid { - cl.Add(diff.CREATE, path, nil, b.Interface()) - return nil - } - if b.Kind() == reflect.Invalid { - cl.Add(diff.DELETE, path, a.Interface(), nil) - return nil - } - - fromPrefix, ok1 := a.Interface().(netip.Prefix) - toPrefix, ok2 := b.Interface().(netip.Prefix) - if !ok1 || !ok2 { - return fmt.Errorf("invalid type for netip.Addr") - } - - if fromPrefix.String() != toPrefix.String() { - cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String()) - } - - return nil -} - -func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) { - differ.DiffFunc = dfunc //nolint -} diff --git a/management/server/dns_test.go b/management/server/dns_test.go index c675fc12c84..8a66da96c0f 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -8,9 +8,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -521,35 +522,56 @@ func TestDNSAccountPeersUpdate(t *testing.T) { } }) - err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ - ID: "groupA", - Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + // Creating DNS settings with groups that have no peers should not update account peers or send peer update + t.Run("creating dns setting with unused groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{ + IP: netip.MustParseAddr(peer1.IP.String()), + NSType: dns.UDPNameServerType, + Port: dns.DefaultDNSPort, + }}, + []string{"groupB"}, + true, []string{}, true, userID, false, + ) 
+ assert.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } }) - assert.NoError(t, err) - _, err = manager.CreateNameServerGroup( - context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ - IP: netip.MustParseAddr(peer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, - }}, - []string{"groupA"}, - true, []string{}, true, userID, false, - ) - assert.NoError(t, err) + // Creating DNS settings with groups that have peers should update account peers and send peer update + t.Run("creating dns setting with used groups", func(t *testing.T) { + err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) - // Saving DNS settings with groups that have peers should update account peers and send peer update - t.Run("saving dns setting with used groups", func(t *testing.T) { done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ - DisabledManagementGroups: []string{"groupA", "groupB"}, - }) + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ + IP: netip.MustParseAddr(peer1.IP.String()), + NSType: dns.UDPNameServerType, + Port: dns.DefaultDNSPort, + }}, + []string{"groupA"}, + true, []string{}, true, userID, false, + ) assert.NoError(t, err) select { @@ -559,12 +581,11 @@ func TestDNSAccountPeersUpdate(t *testing.T) { } }) - // Saving unchanged DNS settings with used groups should update account peers and not send peer update - // since there is no change in the network map - t.Run("saving unchanged dns setting with used groups", func(t *testing.T) { + // Saving DNS settings with groups that have peers should update account peers and send peer update + t.Run("saving dns setting with used groups", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -576,7 +597,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { select { case <-done: case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") + t.Error("timeout waiting for peerShouldReceiveUpdate") } }) diff --git a/management/server/group_test.go b/management/server/group_test.go index 1e59b74ef5b..89184e81927 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -8,12 +8,13 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const ( @@ -536,29 +537,6 @@ func TestGroupAccountPeersUpdate(t *testing.T) { } }) - // Saving an unchanged group should trigger account peers update and not send peer update - // since there is no change in the network map - t.Run("saving unchanged group", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: 
"groupA", - Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID}, - }) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // adding peer to a used group should update account peers and send peer update t.Run("adding peer to linked group", func(t *testing.T) { done := make(chan struct{}) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 96637cd39a0..846dbf02370 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -1065,36 +1065,6 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { } }) - // saving unchanged nameserver group should update account peers and not send peer update - t.Run("saving unchanged nameserver group", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - newNameServerGroupB.NameServers = []nbdns.NameServer{ - { - IP: netip.MustParseAddr("1.1.1.2"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - { - IP: netip.MustParseAddr("8.8.8.8"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }, - } - err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Deleting a nameserver group should update account peers and send peer update t.Run("deleting nameserver group", func(t *testing.T) { done := make(chan struct{}) diff --git a/management/server/network.go b/management/server/network.go index 8fb6a8b3c12..a5b188b4610 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -41,9 +41,9 @@ type Network struct { Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. - Serial uint64 `diff:"-"` + Serial uint64 - mu sync.Mutex `json:"-" gorm:"-" diff:"-"` + mu sync.Mutex `json:"-" gorm:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 82e0acf3ade..34d7918446b 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -20,33 +20,33 @@ type Peer struct { // IP address of the Peer IP net.IP `gorm:"serializer:json"` // Meta is a Peer system meta data - Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"` + Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` // Name is peer's name (machine name) Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. peer-dns-label.netbird.cloud DNSLabel string // Status peer's management connection status - Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"` + Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` // The user ID that registered the peer - UserID string `diff:"-"` + UserID string // SSHKey is a public SSH key of the peer SSHKey string // SSHEnabled indicates whether SSH server is enabled on the peer SSHEnabled bool // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. 
// Works with LastLogin - LoginExpirationEnabled bool `diff:"-"` + LoginExpirationEnabled bool - InactivityExpirationEnabled bool `diff:"-"` + InactivityExpirationEnabled bool // LastLogin the time when peer performed last login operation - LastLogin time.Time `diff:"-"` + LastLogin time.Time // CreatedAt records the time the peer was created - CreatedAt time.Time `diff:"-"` + CreatedAt time.Time // Indicate ephemeral peer attribute - Ephemeral bool `diff:"-"` + Ephemeral bool // Geo location based on connection IP - Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"` + Location Location `gorm:"embedded;embeddedPrefix:location_"` } type PeerStatus struct { //nolint:revive diff --git a/management/server/policy.go b/management/server/policy.go index 05554243032..43a925f8850 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -405,7 +405,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - am.updateAccountPeers(ctx, account) + if anyGroupHasPeers(account, policy.ruleGroups()) { + am.updateAccountPeers(ctx, account) + } return nil } diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 5b1411702b2..e7f0f9cd2f1 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -854,16 +854,11 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) - t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) - }) - // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { policy := Policy{ @@ -883,7 +878,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg1) + peerShouldNotReceiveUpdate(t, updMsg) close(done) }() @@ -918,7 +913,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -953,7 +948,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg2) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -987,7 +982,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1021,7 +1016,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1056,7 +1051,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg1) + peerShouldNotReceiveUpdate(t, updMsg) close(done) }() @@ -1090,7 +1085,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + 
peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1104,46 +1099,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { } }) - // Saving unchanged policy should trigger account peers update but not send peer update - t.Run("saving unchanged policy", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg1) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Deleting policy should trigger account peers update and send peer update t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { policyID := "policy-source-destination-peers" done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1164,7 +1126,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { policyID := "policy-destination-has-peers-source-none" done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg2) + peerShouldReceiveUpdate(t, updMsg) close(done) }() @@ -1180,10 +1142,10 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with no peers in groups should not update account's peers and not send peer update t.Run("deleting policy with no peers in groups", func(t *testing.T) { - policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2 + policyID := "policy-rule-groups-no-peers" done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg1) + peerShouldNotReceiveUpdate(t, updMsg) close(done) }() diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 7d31956f955..c63538b9d52 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -5,10 +5,11 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/posture" ) @@ -264,25 +265,6 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - // Saving unchanged posture check should not trigger account peers update and not send peer update - // since there is no change in the network map - t.Run("saving unchanged posture check", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Removing posture check from policy should trigger account peers update and send peer update t.Run("removing posture check from policy", func(t *testing.T) { done := make(chan struct{}) @@ -412,50 +394,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - // Updating linked posture check to policy where source has peers but destination does not, - 
// should not trigger account peers update or send peer update - t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - policy = Policy{ - ID: "policyB", - Enabled: true, - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupB"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - SourcePostureChecks: []string{postureCheck.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) - assert.NoError(t, err) - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - postureCheck.Checks = posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.29.0", - }, - } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) - assert.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update - t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) { + t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { policy = Policy{ ID: "policyB", Enabled: true, diff --git a/management/server/route_test.go b/management/server/route_test.go index a4b320c7ee2..4893e19b9f3 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1938,26 +1938,6 @@ func TestRouteAccountPeersUpdate(t *testing.T) { } }) - // Updating unchanged route should update account peers and not send peer update - t.Run("updating unchanged route", func(t *testing.T) { - baseRoute.Groups = []string{routeGroup1, routeGroup2} - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) - require.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - // Deleting the route should update account peers and send peer update t.Run("deleting route", func(t *testing.T) { done := make(chan struct{}) diff --git a/management/server/telemetry/updatechannel_metrics.go b/management/server/telemetry/updatechannel_metrics.go index fb33b663c62..2582006e517 100644 --- a/management/server/telemetry/updatechannel_metrics.go +++ b/management/server/telemetry/updatechannel_metrics.go @@ -18,7 +18,6 @@ type UpdateChannelMetrics struct { getAllConnectedPeersDurationMicro metric.Int64Histogram getAllConnectedPeers metric.Int64Histogram hasChannelDurationMicro metric.Int64Histogram - networkMapDiffDurationMicro metric.Int64Histogram ctx context.Context } @@ -64,11 +63,6 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh return nil, err } - networkMapDiffDurationMicro, err := meter.Int64Histogram("management.updatechannel.networkmap.diff.duration.micro") - if err != nil { - return nil, err - } - return &UpdateChannelMetrics{ createChannelDurationMicro: createChannelDurationMicro, closeChannelDurationMicro: closeChannelDurationMicro, @@ -78,7 +72,6 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh 
getAllConnectedPeersDurationMicro: getAllConnectedPeersDurationMicro, getAllConnectedPeers: getAllConnectedPeers, hasChannelDurationMicro: hasChannelDurationMicro, - networkMapDiffDurationMicro: networkMapDiffDurationMicro, ctx: ctx, }, nil } @@ -118,8 +111,3 @@ func (metrics *UpdateChannelMetrics) CountGetAllConnectedPeersDuration(duration func (metrics *UpdateChannelMetrics) CountHasChannelDuration(duration time.Duration) { metrics.hasChannelDurationMicro.Record(metrics.ctx, duration.Microseconds()) } - -// CountNetworkMapDiffDurationMicro counts the duration of the NetworkMapDiff method -func (metrics *UpdateChannelMetrics) CountNetworkMapDiffDurationMicro(duration time.Duration) { - metrics.networkMapDiffDurationMicro.Record(metrics.ctx, duration.Microseconds()) -} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 7c73002223b..59b6fd09492 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -2,16 +2,12 @@ package server import ( "context" - "fmt" - "runtime/debug" "sync" "time" - "github.com/r3labs/diff/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" - "github.com/netbirdio/netbird/management/server/differs" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -25,8 +21,6 @@ type UpdateMessage struct { type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID peerChannels map[string]chan *UpdateMessage - // peerNetworkMaps is the UpdateMessage indexed by Peer.ID. - peerUpdateMessage map[string]*UpdateMessage // channelsMux keeps the mutex to access peerChannels channelsMux *sync.RWMutex // metrics provides method to collect application metrics @@ -36,10 +30,9 @@ type PeersUpdateManager struct { // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), - peerUpdateMessage: make(map[string]*UpdateMessage), - channelsMux: &sync.RWMutex{}, - metrics: metrics, + peerChannels: make(map[string]chan *UpdateMessage), + channelsMux: &sync.RWMutex{}, + metrics: metrics, } } @@ -48,15 +41,6 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda start := time.Now() var found, dropped bool - // skip sending sync update to the peer if there is no change in update message, - // it will not check on turn credential refresh as we do not send network map or client posture checks - if update.NetworkMap != nil { - updated := p.handlePeerMessageUpdate(ctx, peerID, update) - if !updated { - return - } - } - p.channelsMux.Lock() defer func() { @@ -66,16 +50,6 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } }() - if update.NetworkMap != nil { - lastSentUpdate := p.peerUpdateMessage[peerID] - if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() { - log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update", - peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial()) - return - } - p.peerUpdateMessage[peerID] = update - } - if channel, ok := p.peerChannels[peerID]; ok { found = true select { @@ -108,7 +82,6 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c closed = true delete(p.peerChannels, peerID) close(channel) - 
delete(p.peerUpdateMessage, peerID) } // mbragin: todo shouldn't it be more? or configurable? channel := make(chan *UpdateMessage, channelBufferSize) @@ -123,7 +96,6 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) - delete(p.peerUpdateMessage, peerID) } log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) @@ -200,79 +172,3 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } - -// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent. -func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool { - p.channelsMux.RLock() - lastSentUpdate := p.peerUpdateMessage[peerID] - p.channelsMux.RUnlock() - - if lastSentUpdate != nil { - updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update, p.metrics) - if err != nil { - log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err) - return true - } - if !updated { - log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID) - return false - } - } - - return true -} - -// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent. -func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage, metric telemetry.AppMetrics) (isNew bool, err error) { - startTime := time.Now() - - defer func() { - if r := recover(); r != nil { - log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack()) - isNew, err = true, nil - } - }() - - if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() { - return false, nil - } - - differ, err := diff.NewDiffer( - diff.CustomValueDiffers(&differs.NetIPAddr{}), - diff.CustomValueDiffers(&differs.NetIPPrefix{}), - ) - if err != nil { - return false, fmt.Errorf("failed to create differ: %v", err) - } - - lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks) - currFiles := getChecksFiles(currUpdateToSend.Update.Checks) - - changelog, err := differ.Diff(lastSentFiles, currFiles) - if err != nil { - return false, fmt.Errorf("failed to diff checks: %v", err) - } - if len(changelog) > 0 { - return true, nil - } - - changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap) - if err != nil { - return false, fmt.Errorf("failed to diff network map: %v", err) - } - - if metric != nil { - metric.UpdateChannelMetrics().CountNetworkMapDiffDurationMicro(time.Since(startTime)) - } - - return len(changelog) > 0, nil -} - -// getChecksFiles returns a list of files from the given checks. -func getChecksFiles(checks []*proto.Checks) []string { - files := make([]string, 0, len(checks)) - for _, check := range checks { - files = append(files, check.GetFiles()...) 
- } - return files -} diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index b8a0ce45f73..69f5b895c45 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -2,21 +2,10 @@ package server import ( "context" - "net" - "net/netip" "testing" "time" - "github.com/stretchr/testify/assert" - - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/telemetry" - nbroute "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/util" ) // var peersUpdater *PeersUpdateManager @@ -88,474 +77,3 @@ func TestCloseChannel(t *testing.T) { t.Error("Error closing the channel") } } - -func TestHandlePeerMessageUpdate(t *testing.T) { - tests := []struct { - name string - peerID string - existingUpdate *UpdateMessage - newUpdate *UpdateMessage - expectedResult bool - }{ - { - name: "update message with turn credentials update", - peerID: "peer", - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - WiretrusteeConfig: &proto.WiretrusteeConfig{}, - }, - }, - expectedResult: true, - }, - { - name: "update message for peer without existing update", - peerID: "peer1", - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, - }, - expectedResult: true, - }, - { - name: "update message with no changes in update", - peerID: "peer2", - existingUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - expectedResult: false, - }, - { - name: "update message with changes in checks", - peerID: "peer3", - existingUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 2}, - Checks: []*proto.Checks{ - { - Files: []string{"/usr/bin/netbird"}, - }, - }, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, - }, - expectedResult: true, - }, - { - name: "update message with lower serial number", - peerID: "peer4", - existingUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 2}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, - }, - newUpdate: &UpdateMessage{ - Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{Serial: 1}, - }, - NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, - }, - expectedResult: false, - }, - } - - for _, tt := range tests { - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) - if err != nil { - t.Fatal(err) - } - t.Run(tt.name, func(t *testing.T) { - p := NewPeersUpdateManager(metrics) - ctx := context.Background() - - if tt.existingUpdate != nil { - p.peerUpdateMessage[tt.peerID] = tt.existingUpdate - } - - result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate) - assert.Equal(t, tt.expectedResult, result) - }) - } -} - -func 
TestIsNewPeerUpdateMessage(t *testing.T) { - t.Run("Unchanged value", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.False(t, message) - }) - - t.Run("Unchanged value with serial incremented", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.False(t, message) - }) - - t.Run("Updating routes network", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32") - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - - }) - - t.Run("Updating routes groups", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"} - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating network map peers", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newPeer := &nbpeer.Peer{ - IP: net.ParseIP("192.168.1.4"), - SSHEnabled: true, - Key: "peer4-key", - DNSLabel: "peer4", - SSHKey: "peer4-ssh-key", - } - newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating process check", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - - newUpdateMessage2 := createMockUpdateMessage(t) - newUpdateMessage2.Update.NetworkMap.Serial++ - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.False(t, message) - - newUpdateMessage3 := createMockUpdateMessage(t) - newUpdateMessage3.Update.Checks = []*proto.Checks{} - newUpdateMessage3.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3, nil) - assert.NoError(t, err) - assert.True(t, message) - - newUpdateMessage4 := createMockUpdateMessage(t) - check := &posture.Checks{ - Checks: posture.ChecksDefinition{ - ProcessCheck: &posture.ProcessCheck{ - Processes: []posture.Process{ - { - LinuxPath: "/usr/local/netbird", - MacPath: "/usr/bin/netbird", - }, - }, - }, - }, - } - newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)} - newUpdateMessage4.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4, nil) - assert.NoError(t, err) - assert.True(t, message) - - newUpdateMessage5 := createMockUpdateMessage(t) - 
check = &posture.Checks{ - Checks: posture.ChecksDefinition{ - ProcessCheck: &posture.ProcessCheck{ - Processes: []posture.Process{ - { - LinuxPath: "/usr/bin/netbird", - WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", - MacPath: "/usr/local/netbird", - }, - }, - }, - }, - } - newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)} - newUpdateMessage5.Update.NetworkMap.Serial++ - message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating DNS configuration", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newDomain := "newexample.com" - newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append( - newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains, - newDomain, - ) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating peer IP", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10") - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating firewall rule", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443" - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Add new firewall rule", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newRule := &FirewallRule{ - PeerIP: "192.168.1.3", - Direction: firewallRuleDirectionOUT, - Action: string(PolicyTrafficActionDrop), - Protocol: string(PolicyRuleProtocolUDP), - Port: "53", - } - newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Removing nameserver", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0) - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating name server IP", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4") - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, 
newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - - t.Run("Updating custom DNS zone", func(t *testing.T) { - newUpdateMessage1 := createMockUpdateMessage(t) - newUpdateMessage2 := createMockUpdateMessage(t) - - newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2" - newUpdateMessage2.Update.NetworkMap.Serial++ - - message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil) - assert.NoError(t, err) - assert.True(t, message) - }) - -} - -func createMockUpdateMessage(t *testing.T) *UpdateMessage { - t.Helper() - - _, ipNet, err := net.ParseCIDR("192.168.1.0/24") - if err != nil { - t.Fatal(err) - } - domainList, err := domain.FromStringList([]string{"example.com"}) - if err != nil { - t.Fatal(err) - } - - config := &Config{ - Signal: &Host{ - Proto: "https", - URI: "signal.uri", - Username: "", - Password: "", - }, - Stuns: []*Host{{URI: "stun.uri", Proto: UDP}}, - TURNConfig: &TURNConfig{ - Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}}, - }, - } - peer := &nbpeer.Peer{ - IP: net.ParseIP("192.168.1.1"), - SSHEnabled: true, - Key: "peer-key", - DNSLabel: "peer1", - SSHKey: "peer1-ssh-key", - } - - secretManager := NewTimeBasedAuthSecretsManager( - NewPeersUpdateManager(nil), - &TURNConfig{ - TimeBasedCredentials: false, - CredentialsTTL: util.Duration{ - Duration: defaultDuration, - }, - Secret: "secret", - Turns: []*Host{TurnTestHost}, - }, - &Relay{ - Addresses: []string{"localhost:0"}, - CredentialsTTL: util.Duration{Duration: time.Hour}, - Secret: "secret", - }, - ) - - networkMap := &NetworkMap{ - Network: &Network{Net: *ipNet, Serial: 1000}, - Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, - OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, - Routes: []*nbroute.Route{ - { - ID: "route1", - Network: netip.MustParsePrefix("10.0.0.0/24"), - KeepRoute: true, - NetID: "route1", - Peer: "peer1", - NetworkType: 1, - Masquerade: true, - Metric: 9999, - Enabled: true, - Groups: []string{"test1", "test2"}, - }, - { - ID: "route2", - Domains: domainList, - KeepRoute: true, - NetID: "route2", - Peer: "peer1", - NetworkType: 1, - Masquerade: true, - Metric: 9999, - Enabled: true, - Groups: []string{"test1", "test2"}, - }, - }, - DNSConfig: nbdns.Config{ - ServiceEnable: true, - NameServerGroups: []*nbdns.NameServerGroup{ - { - NameServers: []nbdns.NameServer{{ - IP: netip.MustParseAddr("8.8.8.8"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }}, - Primary: true, - Domains: []string{"example.com"}, - Enabled: true, - SearchDomainsEnabled: true, - }, - { - ID: "ns1", - NameServers: []nbdns.NameServer{{ - IP: netip.MustParseAddr("1.1.1.1"), - NSType: nbdns.UDPNameServerType, - Port: nbdns.DefaultDNSPort, - }}, - Groups: []string{"group1"}, - Primary: true, - Domains: []string{"example.com"}, - Enabled: true, - SearchDomainsEnabled: true, - }, - }, - CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, - }, - FirewallRules: []*FirewallRule{ - {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, - }, - } - dnsName := "example.com" - checks 
:= []*posture.Checks{ - { - Checks: posture.ChecksDefinition{ - ProcessCheck: &posture.ProcessCheck{ - Processes: []posture.Process{ - { - LinuxPath: "/usr/bin/netbird", - WindowsPath: "C:\\Program Files\\netbird\\netbird.exe", - MacPath: "/usr/bin/netbird", - }, - }, - }, - }, - }, - } - dnsCache := &DNSConfigCache{} - - turnToken, err := secretManager.GenerateTurnToken() - if err != nil { - t.Fatal(err) - } - - relayToken, err := secretManager.GenerateRelayToken() - if err != nil { - t.Fatal(err) - } - - return &UpdateMessage{ - Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache), - NetworkMap: networkMap, - } -} From ad4f0a6fdfae53a569cf20b78610567a78f83f02 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 31 Oct 2024 23:18:35 +0100 Subject: [PATCH 74/81] [client] Nil check on ICE remote conn (#2806) --- client/internal/peer/conn.go | 5 +++++ client/internal/peer/nilcheck.go | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 client/internal/peer/nilcheck.go diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 56b772759a2..84a8c221fa6 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -309,6 +309,11 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon return } + if remoteConnNil(conn.log, iceConnInfo.RemoteConn) { + conn.log.Errorf("remote ICE connection is nil") + return + } + conn.log.Debugf("ICE connection is ready") if conn.currentConnPriority > priority { diff --git a/client/internal/peer/nilcheck.go b/client/internal/peer/nilcheck.go new file mode 100644 index 00000000000..058fe9a2697 --- /dev/null +++ b/client/internal/peer/nilcheck.go @@ -0,0 +1,21 @@ +package peer + +import ( + "net" + + log "github.com/sirupsen/logrus" +) + +func remoteConnNil(log *log.Entry, conn net.Conn) bool { + if conn == nil { + log.Errorf("ice conn is nil") + return true + } + + if conn.RemoteAddr() == nil { + log.Errorf("ICE remote address is nil") + return true + } + + return false +} From 9812de853bac5c61ac38f6497014e7db1e4e290e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 1 Nov 2024 00:33:25 +0100 Subject: [PATCH 75/81] Allocate new buffer for every package (#2823) --- client/iface/wgproxy/bind/proxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index e986d6d7b07..e0883715a99 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -104,8 +104,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { } }() - buf := make([]byte, 1500) for { + buf := make([]byte, 1500) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { From bac95ace1820d89f5642a528f877aad4e96539d9 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:58:39 +0100 Subject: [PATCH 76/81] [management] Add DB access duration to logs for context cancel (#2781) --- management/server/sql_store.go | 202 ++++++++++++++++++++++++++++-- management/server/status/error.go | 6 + 2 files changed, 198 insertions(+), 10 deletions(-) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 27238d28e8a..b1b8330ba3b 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -292,6 +292,8 @@ func (s *SqlStore) GetInstallationID() string { } func (s *SqlStore) SavePeer(ctx context.Context, 
accountID string, peer *nbpeer.Peer) error { + startTime := time.Now() + // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID @@ -317,6 +319,9 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. }) if err != nil { + if errors.Is(err, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return err } @@ -324,6 +329,8 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. } func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { + startTime := time.Now() + accountCopy := Account{ Domain: domain, DomainCategory: category, @@ -336,6 +343,9 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID Where(idQueryCondition, accountID). Updates(&accountCopy) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return result.Error } @@ -347,6 +357,8 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID } func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { + startTime := time.Now() + var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -359,6 +371,9 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return result.Error } @@ -370,6 +385,8 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe } func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { + startTime := time.Now() + // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. var peerCopy nbpeer.Peer // Since the location field has been migrated to JSON serialization, @@ -381,6 +398,9 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P Updates(peerCopy) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return result.Error } @@ -394,6 +414,8 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P // SaveUsers saves the given list of users to the database. // It updates existing users if a conflict occurs. func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { + startTime := time.Now() + usersToSave := make([]User, 0, len(users)) for _, user := range users { user.AccountID = accountID @@ -403,15 +425,28 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { } usersToSave = append(usersToSave, *user) } - return s.db.Session(&gorm.Session{FullSaveAssociations: true}). + err := s.db.Session(&gorm.Session{FullSaveAssociations: true}). Clauses(clause.OnConflict{UpdateAll: true}). Create(&usersToSave).Error + if err != nil { + if errors.Is(err, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } + return status.Errorf(status.Internal, "failed to save users to store: %v", err) + } + + return nil } // SaveUser saves the given user to the database. 
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { + startTime := time.Now() + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) } return nil @@ -419,12 +454,17 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u // SaveGroups saves the given list of groups to the database. func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { + startTime := time.Now() + if len(groups) == 0 { return nil } result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error) } return nil @@ -451,6 +491,8 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) } func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { + startTime := time.Now() + var accountID string result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", @@ -460,6 +502,9 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } + if errors.Is(result.Error, context.Canceled) { + return "", status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) return "", status.NewGetAccountFromStoreError(result.Error) } @@ -468,12 +513,17 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength } func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { + startTime := time.Now() + var key SetupKey result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewSetupKeyNotFoundError(result.Error) } @@ -485,12 +535,17 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* } func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { + startTime := time.Now() + var token PersonalAccessToken result := s.db.First(&token, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return "", status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting 
token from the store: %s", result.Error) return "", status.NewGetAccountFromStoreError(result.Error) } @@ -499,12 +554,17 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri } func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) { + startTime := time.Now() + var token PersonalAccessToken result := s.db.First(&token, idQueryCondition, tokenID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -528,6 +588,8 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, } func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { + startTime := time.Now() + var user User result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). Preload(clause.Associations).First(&user, idQueryCondition, userID) @@ -535,6 +597,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewGetUserFromStoreError() } @@ -542,12 +607,17 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { + startTime := time.Now() + var users []*User result := s.db.Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting users from store") } @@ -556,12 +626,17 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us } func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { + startTime := time.Now() + var groups []*nbgroup.Group result := s.db.Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting groups from store") } @@ -661,12 +736,17 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { + startTime := time.Now() + var user User result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID) if result.Error != 
nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -678,12 +758,17 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun } func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { + startTime := time.Now() + var peer nbpeer.Peer result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -695,13 +780,17 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco } func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { - var peer nbpeer.Peer + startTime := time.Now() + var peer nbpeer.Peer result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -713,6 +802,8 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( } func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { + startTime := time.Now() + var peer nbpeer.Peer var accountID string result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) @@ -720,6 +811,9 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return "", status.NewStoreContextCanceledError(time.Since(startTime)) + } return "", status.NewGetAccountFromStoreError(result.Error) } @@ -727,12 +821,17 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) } func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { + startTime := time.Now() + var accountID string result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return "", status.NewStoreContextCanceledError(time.Since(startTime)) + } return "", status.NewGetAccountFromStoreError(result.Error) } @@ -740,12 +839,17 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { } func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { + startTime := time.Now() + var accountID string result := 
s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } + if errors.Is(result.Error, context.Canceled) { + return "", status.NewStoreContextCanceledError(time.Since(startTime)) + } return "", status.NewSetupKeyNotFoundError(result.Error) } @@ -757,6 +861,8 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) } func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { + startTime := time.Now() + var ipJSONStrings []string // Fetch the IP addresses as JSON strings @@ -767,6 +873,9 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "no peers found for the account") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error) } @@ -784,8 +893,9 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength } func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { - var labels []string + startTime := time.Now() + var labels []string result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). Where("account_id = ?", accountID). Pluck("dns_label", &labels) @@ -794,6 +904,9 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "no peers found for the account") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error) } @@ -802,24 +915,33 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { - var accountNetwork AccountNetwork + startTime := time.Now() + var accountNetwork AccountNetwork if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } + if errors.Is(err, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) } return accountNetwork.Network, nil } func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { + startTime := time.Now() + var peer nbpeer.Peer result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } + if errors.Is(result.Error, 
context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error) } @@ -827,11 +949,16 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { + startTime := time.Now() + var accountSettings AccountSettings if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } + if errors.Is(err, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err) } return accountSettings.Settings, nil @@ -839,13 +966,17 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { - var user User + startTime := time.Now() + var user User result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewUserNotFoundError(userID) } + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.NewGetUserFromStoreError() } user.LastLogin = lastLogin @@ -854,6 +985,8 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri } func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { + startTime := time.Now() + definitionJSON, err := json.Marshal(checks) if err != nil { return nil, err @@ -862,6 +995,9 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p var postureCheck posture.Checks err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error if err != nil { + if errors.Is(err, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, err } @@ -971,6 +1107,8 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, } func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { + startTime := time.Now() + var setupKey SetupKey result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, keyQueryCondition, key) @@ -978,12 +1116,17 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "setup key not found") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.NewSetupKeyNotFoundError(result.Error) } return &setupKey, nil } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { + startTime := time.Now() + result := s.db.WithContext(ctx).Model(&SetupKey{}). 
Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ @@ -992,6 +1135,9 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string }) if result.Error != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error) } @@ -1003,13 +1149,17 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string } func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - var group nbgroup.Group + startTime := time.Now() + var group nbgroup.Group result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") } + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error) } @@ -1022,6 +1172,9 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) if err := s.db.Save(&group).Error; err != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } @@ -1029,13 +1182,17 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer } func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - var group nbgroup.Group + startTime := time.Now() + var group nbgroup.Group result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group not found for account") } + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue finding group: %s", result.Error) } @@ -1048,6 +1205,9 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) if err := s.db.Save(&group).Error; err != nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue updating group: %s", err) } @@ -1060,7 +1220,12 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { + startTime := time.Now() + if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { + if errors.Is(err, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -1068,8 +1233,13 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { + startTime := time.Now() + result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error 
!= nil { + if errors.Is(result.Error, context.Canceled) { + return status.NewStoreContextCanceledError(time.Since(startTime)) + } return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) } return nil @@ -1100,14 +1270,18 @@ func (s *SqlStore) GetDB() *gorm.DB { } func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) { - var accountDNSSettings AccountDNSSettings + startTime := time.Now() + var accountDNSSettings AccountDNSSettings result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "dns settings not found") } + if errors.Is(result.Error, context.Canceled) { + return nil, status.NewStoreContextCanceledError(time.Since(startTime)) + } return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) } return &accountDNSSettings.DNSSettings, nil @@ -1115,14 +1289,18 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki // AccountExists checks whether an account exists by the given ID. func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { - var accountID string + startTime := time.Now() + var accountID string result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). Select("id").First(&accountID, idQueryCondition, id) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return false, nil } + if errors.Is(result.Error, context.Canceled) { + return false, status.NewStoreContextCanceledError(time.Since(startTime)) + } return false, result.Error } @@ -1131,14 +1309,18 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { - var account Account + startTime := time.Now() + var account Account result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). 
Where(idQueryCondition, accountID).First(&account) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", "", status.Errorf(status.NotFound, "account not found") } + if errors.Is(result.Error, context.Canceled) { + return "", "", status.NewStoreContextCanceledError(time.Since(startTime)) + } return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error) } diff --git a/management/server/status/error.go b/management/server/status/error.go index e9fc8c15ef9..a145edf8002 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -3,6 +3,7 @@ package status import ( "errors" "fmt" + "time" ) const ( @@ -115,6 +116,11 @@ func NewGetUserFromStoreError() error { return Errorf(Internal, "issue getting user from store") } +// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context +func NewStoreContextCanceledError(duration time.Duration) error { + return Errorf(Internal, "store access: context canceled after %v", duration) +} + // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") From 0eb99c266affccaa03d9c363862655edd8798b22 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 1 Nov 2024 12:33:29 +0100 Subject: [PATCH 77/81] Fix unused servers cleanup (#2826) The cleanup loop did not handle those situations well where a connection to a server failed, or where the connection succeeded but the code had not yet added a peer connection to it. - in the cleanup loop, check whether the connection to a server failed - after adding a foreign server connection, keep it for a minimum of 5 seconds --- relay/client/manager.go | 18 +++++++++++++++++- relay/client/manager_test.go | 5 +++-- relay/client/picker_test.go | 3 +-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/relay/client/manager.go b/relay/client/manager.go index 3981415fcd4..b14a7701bfb 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -16,6 +16,7 @@ import ( var ( relayCleanupInterval = 60 * time.Second + keepUnusedServerTime = 5 * time.Second ErrRelayClientNotConnected = fmt.Errorf("relay client not connected") ) @@ -27,10 +28,13 @@ type RelayTrack struct { sync.RWMutex relayClient *Client err error + created time.Time } func NewRelayTrack() *RelayTrack { - return &RelayTrack{} + return &RelayTrack{ + created: time.Now(), + } } type OnServerCloseListener func() @@ -302,6 +306,18 @@ func (m *Manager) cleanUpUnusedRelays() { for addr, rt := range m.relayClients { rt.Lock() + // if the connection failed to the server the relay client will be nil + // but the instance will be kept in the relayClients until the next locking + if rt.err != nil { + rt.Unlock() + continue + } + + if time.Since(rt.created) <= keepUnusedServerTime { + rt.Unlock() + continue + } + if rt.relayClient.HasConns() { rt.Unlock() continue diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index e9cc2c58154..bfc342f25f7 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -288,8 +288,9 @@ func TestForeginAutoClose(t *testing.T) { t.Fatalf("failed to close connection: %s", err) } - t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second) - time.Sleep(relayCleanupInterval + 1*time.Second) + timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second + t.Logf("waiting for relay cleanup: %s", timeout) + time.Sleep(timeout) if 
len(mgr.relayClients) != 0 { t.Errorf("expected 0, got %d", len(mgr.relayClients)) } diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go index eb14581e067..4800e05ba29 100644 --- a/relay/client/picker_test.go +++ b/relay/client/picker_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "testing" - "time" ) func TestServerPicker_UnavailableServers(t *testing.T) { @@ -13,7 +12,7 @@ func TestServerPicker_UnavailableServers(t *testing.T) { PeerID: "test", } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) defer cancel() go func() { From 5f06b202c364dd66e57b8a58f178a8647a6ddfce Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 1 Nov 2024 15:08:22 +0100 Subject: [PATCH 78/81] [client] Log windows panics (#2829) --- client/server/panic_generic.go | 7 +++ client/server/panic_windows.go | 83 ++++++++++++++++++++++++++++++++++ client/server/server.go | 4 ++ 3 files changed, 94 insertions(+) create mode 100644 client/server/panic_generic.go create mode 100644 client/server/panic_windows.go diff --git a/client/server/panic_generic.go b/client/server/panic_generic.go new file mode 100644 index 00000000000..f027b954b34 --- /dev/null +++ b/client/server/panic_generic.go @@ -0,0 +1,7 @@ +//go:build !windows + +package server + +func handlePanicLog() error { + return nil +} diff --git a/client/server/panic_windows.go b/client/server/panic_windows.go new file mode 100644 index 00000000000..1d4ba4b756f --- /dev/null +++ b/client/server/panic_windows.go @@ -0,0 +1,83 @@ +package server + +import ( + "fmt" + "os" + "path/filepath" + "syscall" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util" +) + +const ( + windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG" + // STD_ERROR_HANDLE ((DWORD)-12) = 4294967284 + stdErrorHandle = ^uintptr(11) +) + +var ( + kernel32 = syscall.NewLazyDLL("kernel32.dll") + + // https://learn.microsoft.com/en-us/windows/console/setstdhandle + setStdHandleFn = kernel32.NewProc("SetStdHandle") +) + +func handlePanicLog() error { + logPath := os.Getenv(windowsPanicLogEnvVar) + if logPath == "" { + return nil + } + + // Ensure the directory exists + logDir := filepath.Dir(logPath) + if err := os.MkdirAll(logDir, 0750); err != nil { + return fmt.Errorf("create panic log directory: %w", err) + } + if err := util.EnforcePermission(logPath); err != nil { + return fmt.Errorf("enforce permission on panic log file: %w", err) + } + + // Open log file with append mode + f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + return fmt.Errorf("open panic log file: %w", err) + } + + // Redirect stderr to the file + if err = redirectStderr(f); err != nil { + if closeErr := f.Close(); closeErr != nil { + log.Warnf("failed to close file after redirect error: %v", closeErr) + } + return fmt.Errorf("redirect stderr: %w", err) + } + + log.Infof("successfully configured panic logging to: %s", logPath) + return nil +} + +// redirectStderr redirects stderr to the provided file +func redirectStderr(f *os.File) error { + // Get the current process's stderr handle + if err := setStdHandle(f); err != nil { + return fmt.Errorf("failed to set stderr handle: %w", err) + } + + // Also set os.Stderr for Go's standard library + os.Stderr = f + + return nil +} + +func setStdHandle(f *os.File) error { + handle := f.Fd() + r0, _, e1 := setStdHandleFn.Call(stdErrorHandle, handle) + if r0 == 0 { + 
if e1 != nil { + return e1 + } + return syscall.EINVAL + } + return nil +} diff --git a/client/server/server.go b/client/server/server.go index a0332208194..4d921851f94 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -97,6 +97,10 @@ func (s *Server) Start() error { defer s.mutex.Unlock() state := internal.CtxGetState(s.rootCtx) + if err := handlePanicLog(); err != nil { + log.Warnf("failed to redirect stderr: %v", err) + } + if err := restoreResidualState(s.rootCtx); err != nil { log.Warnf(errRestoreResidualState, err) } From a9d06b883fe742c5dd03b822ba2385203e1b1682 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 1 Nov 2024 22:09:08 +0100 Subject: [PATCH 79/81] add all group to add peer affected peers network map check (#2830) --- management/server/peer.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/management/server/peer.go b/management/server/peer.go index 96ede151158..7cc2209c5e0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -589,6 +589,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, fmt.Errorf("error getting account: %w", err) } + allGroup, err := account.GetGroupAll() + if err != nil { + return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err) + } + + groupsToAdd = append(groupsToAdd, allGroup.ID) if areGroupChangesAffectPeers(account, groupsToAdd) { am.updateAccountPeers(ctx, account) } From 5b46cc8e9cd35c9abaf43af6d48f9a34fe3085fa Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 5 Nov 2024 13:28:42 +0100 Subject: [PATCH 80/81] Avoid failing all other matrix tests if one fails (#2839) --- .github/workflows/golang-test-linux.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index b584f0ff68c..ef66720024d 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -13,6 +13,7 @@ concurrency: jobs: test: strategy: + fail-fast: false matrix: arch: [ '386','amd64' ] store: [ 'sqlite', 'postgres'] From b952d8693d670ae56ef12348d14a3779b2535740 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 5 Nov 2024 14:51:17 +0100 Subject: [PATCH 81/81] Fix cached device flow oauth (#2833) This change removes the cached device flow oauth info when a down command is called, removing the need for the agent to be restarted. --- client/server/server.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client/server/server.go b/client/server/server.go index 4d921851f94..106bdf32bbf 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -626,6 +626,8 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes s.mutex.Lock() defer s.mutex.Unlock() + s.oauthAuthFlow = oauthAuthFlow{} + if s.actCancel == nil { return nil, fmt.Errorf("service is not up") }
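
PATCH 76 applies one pattern throughout sql_store.go: record a start time before the store call and, when the request context is canceled, return an error that reports how long the DB access ran before it was cut off. The following minimal, self-contained Go sketch illustrates that pattern only; the helper and function names (storeContextCanceledError, getAccountID) are illustrative stand-ins and are not part of the patch.

	package main

	import (
		"context"
		"errors"
		"fmt"
		"time"
	)

	// storeContextCanceledError mirrors the idea behind status.NewStoreContextCanceledError:
	// surface the elapsed store-access time in the error message.
	func storeContextCanceledError(d time.Duration) error {
		return fmt.Errorf("store access: context canceled after %v", d)
	}

	// getAccountID stands in for any SqlStore getter: it times the query and
	// converts context.Canceled into a duration-carrying error.
	func getAccountID(ctx context.Context, query func(context.Context) (string, error)) (string, error) {
		startTime := time.Now()
		id, err := query(ctx)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				return "", storeContextCanceledError(time.Since(startTime))
			}
			return "", fmt.Errorf("lookup failed: %w", err)
		}
		return id, nil
	}

	func main() {
		ctx, cancel := context.WithCancel(context.Background())
		cancel() // simulate a caller that gave up before the query finished

		_, err := getAccountID(ctx, func(ctx context.Context) (string, error) {
			return "", ctx.Err() // a real query would return the driver's wrapped error
		})
		fmt.Println(err) // prints the canceled-after-duration message
	}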