diff --git a/go.mod b/go.mod index a76f321717a..d435e4eb8c3 100644 --- a/go.mod +++ b/go.mod @@ -57,8 +57,8 @@ require ( github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22 + github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 + github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index 5f3826c51ce..cc7a52ed636 100644 --- a/go.sum +++ b/go.sum @@ -376,10 +376,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= -github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22 h1:XTiNnVB6OEwung8WIiGJNzOTLVefuSzAA/cu+6Sst8A= -github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22 h1:FNc4p8RS/gFm5jlmvUFWC4/5YxPDWejYyqEBVziFZwo= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM= +github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 h1:yzcQKizAK9YufCHMMCIsr467Dw/OU/4xyHbWizGb1E4= +github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 h1:OFlzVZtkXCoJsfDKrMigFpuad8ZXTm8epq6x27K0irA= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= diff --git a/management/cmd/management.go b/management/cmd/management.go index fc7417f77f6..002cf36d887 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -242,7 +242,10 @@ var ( UserIDClaim: config.HttpConfig.AuthUserIDClaim, KeysLocation: config.HttpConfig.AuthKeysLocation, } - httpAPIHandler, err := httpapi.APIHandler(accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg) + + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } @@ -264,8 +267,6 @@ var ( } if !disableMetrics { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() idpManager := "disabled" if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" { idpManager = config.IdpManagerConfig.ManagerType diff --git a/management/server/account.go b/management/server/account.go index db5aece7537..1d9fd5d4866 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -72,6 +72,7 @@ type AccountManager interface { CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) DeleteAccount(accountID, userID string) error + GetUsage(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) MarkPATUsed(tokenID string) error GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(accountID string) ([]*User, error) @@ -110,7 +111,7 @@ type AccountManager interface { DeleteNameServerGroup(accountID, nsGroupID, userID string) error ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) GetDNSDomain() string - StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any) + StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(accountID, userID string) ([]*activity.Event, error) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error @@ -230,6 +231,14 @@ type Account struct { RulesG []Rule `json:"-" gorm:"-"` } +// AccountUsageStats represents the current usage statistics for an account +type AccountUsageStats struct { + ActiveUsers int64 `json:"active_users"` + TotalUsers int64 `json:"total_users"` + ActivePeers int64 `json:"active_peers"` + TotalPeers int64 `json:"total_peers"` +} + type UserInfo struct { ID string `json:"id"` Email string `json:"email"` @@ -1105,8 +1114,20 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error { return nil } +// GetUsage returns the usage stats for the given account. +// This cannot be used to calculate usage stats for a period in the past as it relies on peers' last seen time. +func (am *DefaultAccountManager) GetUsage(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) { + usageStats, err := am.Store.CalculateUsageStats(ctx, accountID, start, end) + if err != nil { + return nil, fmt.Errorf("failed to calculate usage stats: %w", err) + } + + return usageStats, nil +} + // GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and // userID doesn't have an account associated with it, one account is created +// domain is used to create a new account if no account is found func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) { if accountID != "" { return am.Store.GetAccount(accountID) @@ -1791,7 +1812,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut return nil } -// addAllGroup to account object if it doesn't exists +// addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { allGroup := &Group{ diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 9f4fc55589b..e179fd14d38 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -1,12 +1,14 @@ package activity +import "maps" + // Activity that triggered an Event type Activity int // Code is an activity string representation type Code struct { - message string - code string + Message string + Code string } const ( @@ -207,7 +209,7 @@ var activityMap = map[Activity]Code{ // StringCode returns a string code of the activity func (a Activity) StringCode() string { if code, ok := activityMap[a]; ok { - return code.code + return code.Code } return "UNKNOWN_ACTIVITY" } @@ -215,7 +217,12 @@ func (a Activity) StringCode() string { // Message returns a string representation of an activity func (a Activity) Message() string { if code, ok := activityMap[a]; ok { - return code.message + return code.Message } return "UNKNOWN_ACTIVITY" } + +// RegisterActivityMap adds new codes to the activity map +func RegisterActivityMap(codes map[Activity]Code) { + maps.Copy(activityMap, codes) +} diff --git a/management/server/activity/event.go b/management/server/activity/event.go index f212f5b21b3..24954712217 100644 --- a/management/server/activity/event.go +++ b/management/server/activity/event.go @@ -8,12 +8,18 @@ const ( SystemInitiator = "sys" ) +// ActivityDescriber is an interface that describes an activity +type ActivityDescriber interface { + StringCode() string + Message() string +} + // Event represents a network/system activity event. type Event struct { // Timestamp of the event Timestamp time.Time // Activity that was performed during the event - Activity Activity + Activity ActivityDescriber // ID of the event (can be empty, meaning that it wasn't yet generated) ID uint64 // InitiatorID is the ID of an object that initiated the event (e.g., a user) diff --git a/management/server/event.go b/management/server/event.go index 58ee9547f10..dd253717a93 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -54,8 +54,7 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit return filtered, nil } -func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity, - meta map[string]any) { +func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { go func() { _, err := am.eventStore.Save(&activity.Event{ diff --git a/management/server/file_store.go b/management/server/file_store.go index 0228285cbe9..ad514781fbc 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -1,6 +1,8 @@ package server import ( + "context" + "fmt" "os" "path/filepath" "strings" @@ -662,3 +664,40 @@ func (s *FileStore) Close() error { func (s *FileStore) GetStoreEngine() StoreEngine { return FileStoreEngine } + +// CalculateUsageStats returns the usage stats for an account +// start and end are inclusive. +func (s *FileStore) CalculateUsageStats(_ context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, exists := s.Accounts[accountID] + if !exists { + return nil, fmt.Errorf("account not found") + } + + stats := &AccountUsageStats{ + TotalUsers: 0, + TotalPeers: int64(len(account.Peers)), + } + + for _, user := range account.Users { + if !user.IsServiceUser { + stats.TotalUsers++ + } + } + + activeUsers := make(map[string]bool) + for _, peer := range account.Peers { + lastSeen := peer.Status.LastSeen + if lastSeen.Compare(start) >= 0 && lastSeen.Compare(end) <= 0 { + if _, exists := account.Users[peer.UserID]; exists && !activeUsers[peer.UserID] { + activeUsers[peer.UserID] = true + stats.ActiveUsers++ + } + stats.ActivePeers++ + } + } + + return stats, nil +} diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index e0868fb49b6..083a062b637 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/sha256" "net" "path/filepath" @@ -657,3 +658,32 @@ func newStore(t *testing.T) *FileStore { return store } + +func TestFileStore_CalculateUsageStats(t *testing.T) { + storeDir := t.TempDir() + + err := util.CopyFileContents("testdata/store_stats.json", filepath.Join(storeDir, "store.json")) + require.NoError(t, err) + + store, err := NewFileStore(storeDir, nil) + require.NoError(t, err) + + startDate := time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC) + endDate := startDate.AddDate(0, 1, 0).Add(-time.Nanosecond) + + stats1, err := store.CalculateUsageStats(context.TODO(), "account-1", startDate, endDate) + require.NoError(t, err) + + assert.Equal(t, int64(2), stats1.ActiveUsers) + assert.Equal(t, int64(4), stats1.TotalUsers) + assert.Equal(t, int64(3), stats1.ActivePeers) + assert.Equal(t, int64(7), stats1.TotalPeers) + + stats2, err := store.CalculateUsageStats(context.TODO(), "account-2", startDate, endDate) + require.NoError(t, err) + + assert.Equal(t, int64(1), stats2.ActiveUsers) + assert.Equal(t, int64(2), stats2.TotalUsers) + assert.Equal(t, int64(1), stats2.ActivePeers) + assert.Equal(t, int64(2), stats2.TotalPeers) +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 75c4b277d74..4aab513a7a1 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -1,6 +1,8 @@ package http import ( + "context" + "fmt" "net/http" "github.com/gorilla/mux" @@ -15,6 +17,8 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" ) +const apiPrefix = "/api" + // AuthCfg contains parameters for authentication middleware type AuthCfg struct { Issuer string @@ -35,7 +39,7 @@ type emptyObject struct { } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { +func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -61,7 +65,8 @@ func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Ge rootRouter := mux.NewRouter() metricsMiddleware := appMetrics.HTTPMiddleware() - router := rootRouter.PathPrefix("/api").Subrouter() + prefix := apiPrefix + router := rootRouter.PathPrefix(prefix).Subrouter() router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) api := apiHandler{ @@ -71,7 +76,10 @@ func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Ge AuthCfg: authCfg, } - integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor) + if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil { + return nil, fmt.Errorf("register integrations endpoints: %w", err) + } + api.addAccountsEndpoint() api.addPeersEndpoint() api.addUsersEndpoint() diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 2f9374fc161..de386f173d9 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" @@ -36,9 +37,13 @@ func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessCont var tokenPathRegexp = regexp.MustCompile(`^.*/api/users/.*/tokens.*$`) // Handler method of the middleware which forbids all modify requests for non admin users -// It also adds func (a *AccessControl) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if bypass.ShouldBypass(r.URL.Path, h, w, r) { + return + } + claims := a.claimsExtract.FromRequestContext(r) user, err := a.getUser(claims) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 766a0c235c0..204c9f4eba1 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" @@ -66,6 +67,11 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse // Handler method of the middleware which authenticates a user either by JWT claims or by PAT func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if bypass.ShouldBypass(r.URL.Path, h, w, r) { + return + } + auth := strings.Split(r.Header.Get("Authorization"), " ") authType := strings.ToLower(auth[0]) diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 5fa73ea3aa5..3e3e8419446 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -10,6 +10,7 @@ import ( "github.com/golang-jwt/jwt" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/jwtclaims" ) @@ -88,39 +89,68 @@ func mockCheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error func TestAuthMiddleware_Handler(t *testing.T) { tt := []struct { name string + path string authHeader string expectedStatusCode int + shouldBypassAuth bool }{ { name: "Valid PAT Token", + path: "/test", authHeader: "Token " + PAT, expectedStatusCode: 200, }, { name: "Invalid PAT Token", + path: "/test", authHeader: "Token " + wrongToken, expectedStatusCode: 401, }, { name: "Fallback to PAT Token", + path: "/test", authHeader: "Bearer " + PAT, expectedStatusCode: 200, }, { name: "Valid JWT Token", + path: "/test", authHeader: "Bearer " + JWT, expectedStatusCode: 200, }, { name: "Invalid JWT Token", + path: "/test", authHeader: "Bearer " + wrongToken, expectedStatusCode: 401, }, { name: "Basic Auth", + path: "/test", authHeader: "Basic " + PAT, expectedStatusCode: 401, }, + { + name: "Webhook Path Bypass", + path: "/webhook", + authHeader: "", + expectedStatusCode: 200, + shouldBypassAuth: true, + }, + { + name: "Webhook Path Bypass with Subpath", + path: "/webhook/test", + authHeader: "", + expectedStatusCode: 200, + shouldBypassAuth: true, + }, + { + name: "Different Webhook Path", + path: "/webhooktest", + authHeader: "", + expectedStatusCode: 401, + shouldBypassAuth: false, + }, } nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -146,7 +176,11 @@ func TestAuthMiddleware_Handler(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - req := httptest.NewRequest("GET", "http://testing", nil) + if tc.shouldBypassAuth { + bypass.AddBypassPath(tc.path) + } + + req := httptest.NewRequest("GET", "http://testing"+tc.path, nil) req.Header.Set("Authorization", tc.authHeader) rec := httptest.NewRecorder() @@ -159,5 +193,4 @@ func TestAuthMiddleware_Handler(t *testing.T) { } }) } - } diff --git a/management/server/http/middleware/bypass/bypass.go b/management/server/http/middleware/bypass/bypass.go new file mode 100644 index 00000000000..2f2652eb61d --- /dev/null +++ b/management/server/http/middleware/bypass/bypass.go @@ -0,0 +1,39 @@ +package bypass + +import ( + "net/http" + "sync" +) + +var byPassMutex sync.RWMutex + +// bypassPaths is a set of paths that should bypass middleware. +var bypassPaths = make(map[string]struct{}) + +// AddBypassPath adds an exact path to the list of paths that bypass middleware. +func AddBypassPath(path string) { + byPassMutex.Lock() + defer byPassMutex.Unlock() + bypassPaths[path] = struct{}{} +} + +// RemovePath removes a path from the list of paths that bypass middleware. +func RemovePath(path string) { + byPassMutex.Lock() + defer byPassMutex.Unlock() + delete(bypassPaths, path) +} + +// ShouldBypass checks if the request path is one of the auth bypass paths and returns true if the middleware should be bypassed. +// This can be used to bypass authz/authn middlewares for certain paths, such as webhooks that implement their own authentication. +func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *http.Request) bool { + byPassMutex.RLock() + defer byPassMutex.RUnlock() + + if _, ok := bypassPaths[requestPath]; ok { + h.ServeHTTP(w, r) + return true + } + + return false +} diff --git a/management/server/http/middleware/bypass/bypass_test.go b/management/server/http/middleware/bypass/bypass_test.go new file mode 100644 index 00000000000..efcfe1c3d88 --- /dev/null +++ b/management/server/http/middleware/bypass/bypass_test.go @@ -0,0 +1,103 @@ +package bypass_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/http/middleware/bypass" +) + +func TestAuthBypass(t *testing.T) { + dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + pathToAdd string + pathToRemove string + testPath string + expectBypass bool + expectHTTPCode int + }{ + { + name: "Path added to bypass", + pathToAdd: "/bypass", + testPath: "/bypass", + expectBypass: true, + expectHTTPCode: http.StatusOK, + }, + { + name: "Path not added to bypass", + testPath: "/no-bypass", + expectBypass: false, + expectHTTPCode: http.StatusOK, + }, + { + name: "Path removed from bypass", + pathToAdd: "/remove-bypass", + pathToRemove: "/remove-bypass", + testPath: "/remove-bypass", + expectBypass: false, + expectHTTPCode: http.StatusOK, + }, + { + name: "Exact path matches bypass", + pathToAdd: "/webhook", + testPath: "/webhook", + expectBypass: true, + expectHTTPCode: http.StatusOK, + }, + { + name: "Subpath does not match bypass", + pathToAdd: "/webhook", + testPath: "/webhook/extra", + expectBypass: false, + expectHTTPCode: http.StatusOK, + }, + { + name: "Similar path does not match bypass", + pathToAdd: "/webhook", + testPath: "/webhooking", + expectBypass: false, + expectHTTPCode: http.StatusOK, + }, + { + name: "Prefix path does not match bypass", + pathToAdd: "/webhook", + testPath: "/web", + expectBypass: false, + expectHTTPCode: http.StatusOK, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.pathToAdd != "" { + bypass.AddBypassPath(tc.pathToAdd) + defer bypass.RemovePath(tc.pathToAdd) + } + + if tc.pathToRemove != "" { + bypass.RemovePath(tc.pathToRemove) + } + + request, err := http.NewRequest("GET", tc.testPath, nil) + require.NoError(t, err, "Creating request should not fail") + + recorder := httptest.NewRecorder() + + bypassed := bypass.ShouldBypass(tc.testPath, dummyHandler, recorder, request) + + assert.Equal(t, tc.expectBypass, bypassed, "Bypass check did not match expectation") + + if tc.expectBypass { + assert.Equal(t, tc.expectHTTPCode, recorder.Code, "HTTP status code did not match expectation for bypassed path") + } + }) + } +} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 9dd0810c7b1..7d4161d3b0e 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -1,6 +1,7 @@ package mock_server import ( + "context" "net" "time" @@ -75,7 +76,7 @@ type MockAccountManager struct { CheckUserAccessByJWTGroupsFunc func(claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(accountID, userID string) error GetDNSDomainFunc func() string - StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any) + StoreEventFunc func(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error @@ -91,6 +92,7 @@ type MockAccountManager struct { SavePostureChecksFunc func(accountID, userID string, postureChecks *posture.Checks) error DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) + GetUsageFunc func(ctx context.Context, accountID string, start, end time.Time) (*server.AccountUsageStats, error) } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface @@ -646,7 +648,7 @@ func (am *MockAccountManager) GetAllConnectedPeers() (map[string]struct{}, error return nil, status.Errorf(codes.Unimplemented, "method GetAllConnectedPeers is not implemented") } -// HasconnectedChannel mocks HasConnectedChannel of the AccountManager interface +// HasConnectedChannel mocks HasConnectedChannel of the AccountManager interface func (am *MockAccountManager) HasConnectedChannel(peerID string) bool { if am.HasConnectedChannelFunc != nil { return am.HasConnectedChannelFunc(peerID) @@ -655,7 +657,7 @@ func (am *MockAccountManager) HasConnectedChannel(peerID string) bool { } // StoreEvent mocks StoreEvent of the AccountManager interface -func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any) { +func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { if am.StoreEventFunc != nil { am.StoreEventFunc(initiatorID, targetID, accountID, activityID, meta) } @@ -702,3 +704,11 @@ func (am *MockAccountManager) ListPostureChecks(accountID, userID string) ([]*po } return nil, status.Errorf(codes.Unimplemented, "method ListPostureChecks is not implemented") } + +// GetUsage mocks GetCurrentUsage of the AccountManager interface +func (am *MockAccountManager) GetUsage(ctx context.Context, accountID string, start time.Time, end time.Time) (*server.AccountUsageStats, error) { + if am.GetUsageFunc != nil { + return am.GetUsageFunc(ctx, accountID, start, end) + } + return nil, status.Errorf(codes.Unimplemented, "method GetUsage is not implemented") +} diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index ab6f88c2b5a..c12ad1eddd5 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -1,6 +1,8 @@ package server import ( + "context" + "fmt" "path/filepath" "runtime" "strings" @@ -483,11 +485,11 @@ func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time return s.db.Save(user).Error } -// Close is noop in Sqlite +// Close closes the underlying DB connection func (s *SqliteStore) Close() error { sql, err := s.db.DB() if err != nil { - return err + return fmt.Errorf("get db: %w", err) } return sql.Close() } @@ -496,3 +498,48 @@ func (s *SqliteStore) Close() error { func (s *SqliteStore) GetStoreEngine() StoreEngine { return SqliteStoreEngine } + +// CalculateUsageStats returns the usage stats for an account +// start and end are inclusive. +func (s *SqliteStore) CalculateUsageStats(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) { + stats := &AccountUsageStats{} + + err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + err := tx.Model(&nbpeer.Peer{}). + Where("account_id = ? AND peer_status_last_seen BETWEEN ? AND ?", accountID, start, end). + Distinct("user_id"). + Count(&stats.ActiveUsers).Error + if err != nil { + return fmt.Errorf("get active users: %w", err) + } + + err = tx.Model(&User{}). + Where("account_id = ? AND is_service_user = ?", accountID, false). + Count(&stats.TotalUsers).Error + if err != nil { + return fmt.Errorf("get total users: %w", err) + } + + err = tx.Model(&nbpeer.Peer{}). + Where("account_id = ? AND peer_status_last_seen BETWEEN ? AND ?", accountID, start, end). + Count(&stats.ActivePeers).Error + if err != nil { + return fmt.Errorf("get active peers: %w", err) + } + + err = tx.Model(&nbpeer.Peer{}). + Where("account_id = ?", accountID). + Count(&stats.TotalPeers).Error + if err != nil { + return fmt.Errorf("get total peers: %w", err) + } + + return nil + }) + + if err != nil { + return nil, fmt.Errorf("transaction: %w", err) + } + + return stats, nil +} diff --git a/management/server/sqlite_store_test.go b/management/server/sqlite_store_test.go index 29b49d7f3b1..e85ee211ff1 100644 --- a/management/server/sqlite_store_test.go +++ b/management/server/sqlite_store_test.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "net" "path/filepath" @@ -346,3 +347,29 @@ func newAccount(store Store, id int) error { return store.SaveAccount(account) } + +func TestSqliteStore_CalculateUsageStats(t *testing.T) { + store := newSqliteStoreFromFile(t, "testdata/store_stats.json") + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + startDate := time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC) + endDate := startDate.AddDate(0, 1, 0).Add(-time.Nanosecond) + + stats1, err := store.CalculateUsageStats(context.TODO(), "account-1", startDate, endDate) + require.NoError(t, err) + + assert.Equal(t, int64(2), stats1.ActiveUsers) + assert.Equal(t, int64(4), stats1.TotalUsers) + assert.Equal(t, int64(3), stats1.ActivePeers) + assert.Equal(t, int64(7), stats1.TotalPeers) + + stats2, err := store.CalculateUsageStats(context.TODO(), "account-2", startDate, endDate) + require.NoError(t, err) + + assert.Equal(t, int64(1), stats2.ActiveUsers) + assert.Equal(t, int64(2), stats2.TotalUsers) + assert.Equal(t, int64(1), stats2.ActivePeers) + assert.Equal(t, int64(2), stats2.TotalPeers) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 1ced95e097d..66e46151948 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -1,6 +1,7 @@ package status import ( + "errors" "fmt" ) @@ -68,7 +69,8 @@ func FromError(err error) (s *Error, ok bool) { if err == nil { return nil, true } - if e, ok := err.(*Error); ok { + var e *Error + if errors.As(err, &e) { return e, true } return nil, false diff --git a/management/server/store.go b/management/server/store.go index e3a945c6477..5570108121b 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "os" "path/filepath" @@ -41,6 +42,7 @@ type Store interface { // GetStoreEngine should return StoreEngine of the current store implementation. // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine + CalculateUsageStats(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) } type StoreEngine string diff --git a/management/server/testdata/store_stats.json b/management/server/testdata/store_stats.json new file mode 100644 index 00000000000..747916370c0 --- /dev/null +++ b/management/server/testdata/store_stats.json @@ -0,0 +1,161 @@ +{ + "Accounts": { + "account-1": { + "Id": "account-1", + "Domain": "example.com", + "Network": { + "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", + "Net": { + "IP": "100.64.0.0", + "Mask": "//8AAA==" + }, + "Dns": null + }, + "Users": { + "user-1-account-1": { + "Id": "user-1-account-1" + }, + "user-2-account-1": { + "Id": "user-2-account-1" + }, + "user-3-account-1": { + "Id": "user-3-account-1" + }, + "user-4-account-1": { + "Id": "user-4-account-1" + }, + "user-5-account-1": { + "Id": "user-5-account-1", + "IsServiceUser": true + } + }, + "Peers": { + "peer-1-account-1": { + "ID": "peer-1-account-1", + "UserID": "user-1-account-1", + "Status": { + "LastSeen": "2024-01-01T00:00:00Z" + }, + "Name": "Peer One", + "Meta": { + "Hostname": "peer1-host" + } + }, + "peer-2-account-1": { + "ID": "peer-2-account-1", + "UserID": "user-2-account-1", + "Status": { + "LastSeen": "2024-02-29T23:59:59Z" + }, + "Name": "Peer Two", + "Meta": { + "Hostname": "peer2-host" + } + }, + "peer-3-account-1": { + "ID": "peer-3-account-1", + "UserID": "user-2-account-1", + "Status": { + "LastSeen": "2024-02-01T12:00:00Z" + }, + "Name": "Peer Three", + "Meta": { + "Hostname": "peer3-host" + } + }, + "peer-4-account-1": { + "ID": "peer-4-account-1", + "UserID": "user-3-account-1", + "Status": { + "LastSeen": "2024-02-08T12:00:00Z" + }, + "Name": "Peer Four", + "Meta": { + "Hostname": "peer4-host" + } + }, + "peer-5-account-1": { + "ID": "peer-5-account-1", + "UserID": "user-3-account-1", + "Status": { + "LastSeen": "2023-06-01T12:00:00Z" + }, + "Name": "Peer Five", + "Meta": { + "Hostname": "peer5-host" + } + }, + "peer-6-account-1": { + "ID": "peer-6-account-1", + "UserID": "user-4-account-1", + "Status": { + "LastSeen": "2024-01-31T23:59:59Z" + }, + "Name": "Peer Six", + "Meta": { + "Hostname": "peer6-host" + } + }, + "peer-7-account-1": { + "ID": "peer-7-account-1", + "UserID": "user-4-account-1", + "Status": { + "LastSeen": "2024-03-01T00:00:00Z" + }, + "Name": "Peer Seven", + "Meta": { + "Hostname": "peer7-host" + } + } + } + }, + "account-2": { + "Id": "account-2", + "Domain": "example.org", + "Network": { + "Id": "af1c8024-ha40-4ce2-9418-34653101fc3c", + "Net": { + "IP": "100.64.0.0", + "Mask": "//8AAA==" + }, + "Dns": null + }, + "Users": { + "user-1-account-2": { + "Id": "user-1-account-2" + }, + "user-2-account-2": { + "Id": "user-1-account-2" + }, + "user-3-account-2": { + "Id": "user-3-account-2", + "IsServiceUser": true + } + }, + "Peers": { + "peer-1-account-2": { + "ID": "peer-1-account-2", + "UserID": "user-1-account-2", + "Status": { + "LastSeen": "2023-08-30T12:00:00Z" + }, + "Name": "Peer One", + "Meta": { + "Hostname": "peer1-host" + } + }, + "peer-2-account-2": { + "ID": "peer-2-account-2", + "UserID": "user-1-account-2", + "Status": { + "LastSeen": "2024-02-08T12:00:00Z" + }, + "Name": "Peer Two", + "Meta": { + "Hostname": "peer2-host" + } + } + } + } + } +}