Skip to content

Commit

Permalink
Add account usage endpoint
Browse files Browse the repository at this point in the history
- Adds bypass middleware functionality
- Passes app context to API handler
  • Loading branch information
lixmal committed Feb 12, 2024
1 parent 88747e3 commit b8c81be
Show file tree
Hide file tree
Showing 19 changed files with 544 additions and 18 deletions.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ require (
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22 h1:XTiNnVB6OEwung8WIiGJNzOTLVefuSzAA/cu+6Sst8A=
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22 h1:FNc4p8RS/gFm5jlmvUFWC4/5YxPDWejYyqEBVziFZwo=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM=
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 h1:yzcQKizAK9YufCHMMCIsr467Dw/OU/4xyHbWizGb1E4=
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 h1:OFlzVZtkXCoJsfDKrMigFpuad8ZXTm8epq6x27K0irA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
Expand Down
8 changes: 5 additions & 3 deletions management/cmd/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"google.golang.org/grpc/keepalive"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"

"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
Expand Down Expand Up @@ -234,7 +235,10 @@ var (
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
httpAPIHandler, err := httpapi.APIHandler(accountManager, *jwtValidator, appMetrics, httpAPIAuthCfg)

ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, *jwtValidator, appMetrics, httpAPIAuthCfg)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
Expand All @@ -256,8 +260,6 @@ var (
}

if !disableMetrics {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
idpManager := "disabled"
if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" {
idpManager = config.IdpManagerConfig.ManagerType
Expand Down
26 changes: 25 additions & 1 deletion management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type AccountManager interface {
CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(accountID, userID string) error
GetCurrentUsage(ctx context.Context, accountID string) (*AccountUsageStats, error)
MarkPATUsed(tokenID string) error
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(accountID string) ([]*User, error)
Expand Down Expand Up @@ -221,6 +222,14 @@ type Account struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}

// AccountUsageStats represents the current usage statistics for an account
type AccountUsageStats struct {
ActiveUsers int64 `json:"active_users"`
TotalUsers int64 `json:"total_users"`
ActivePeers int64 `json:"active_peers"`
TotalPeers int64 `json:"total_peers"`
}

type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Expand Down Expand Up @@ -1094,8 +1103,23 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
return nil
}

// GetCurrentUsage returns the usage stats for the given account.
// This cannot be used to calculate usage stats for a period in the past as it relies on peers' last seen time.
func (am *DefaultAccountManager) GetCurrentUsage(ctx context.Context, accountID string) (*AccountUsageStats, error) {
now := time.Now().UTC()

startOfMonth := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
usageStats, err := am.Store.CalculateUsageStats(ctx, accountID, startOfMonth, time.Now().UTC())
if err != nil {
return nil, fmt.Errorf("failed to calculate usage stats: %w", err)
}

return usageStats, nil
}

// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
// userID doesn't have an account associated with it, one account is created
// domain is used to create a new account if no account is found
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
if accountID != "" {
return am.Store.GetAccount(accountID)
Expand Down Expand Up @@ -1781,7 +1805,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
return nil
}

// addAllGroup to account object if it doesn't exists
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
allGroup := &Group{
Expand Down
30 changes: 30 additions & 0 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package server

import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -660,3 +662,31 @@ func (s *FileStore) Close() error {
func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}

// CalculateUsageStats returns the usage stats for an account
// start and end are inclusive.
func (s *FileStore) CalculateUsageStats(_ context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) {
account, exists := s.Accounts[accountID]
if !exists {
return nil, fmt.Errorf("account not found")
}

stats := &AccountUsageStats{
TotalUsers: int64(len(account.Users)),
TotalPeers: int64(len(account.Peers)),
}

activeUsers := make(map[string]bool)
for _, peer := range account.Peers {
if (peer.Status.LastSeen.Equal(start) || peer.Status.LastSeen.After(start)) &&
(peer.Status.LastSeen.Equal(end) || peer.Status.LastSeen.Before(end)) {
if _, exists := account.Users[peer.UserID]; exists && !activeUsers[peer.UserID] {
activeUsers[peer.UserID] = true
stats.ActiveUsers++
}
stats.ActivePeers++
}
}

return stats, nil
}
30 changes: 30 additions & 0 deletions management/server/file_store_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"crypto/sha256"
"net"
"path/filepath"
Expand Down Expand Up @@ -643,3 +644,32 @@ func newStore(t *testing.T) *FileStore {

return store
}

func TestFileStore_CalculateStats(t *testing.T) {
storeDir := t.TempDir()

err := util.CopyFileContents("testdata/store_stats.json", filepath.Join(storeDir, "store.json"))
require.NoError(t, err)

store, err := NewFileStore(storeDir, nil)
require.NoError(t, err)

startDate := time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC)
endDate := startDate.AddDate(0, 1, 0).Add(-time.Nanosecond)

stats1, err := store.CalculateUsageStats(context.TODO(), "account-1", startDate, endDate)
require.NoError(t, err)

assert.Equal(t, int64(2), stats1.ActiveUsers)
assert.Equal(t, int64(4), stats1.TotalUsers)
assert.Equal(t, int64(3), stats1.ActivePeers)
assert.Equal(t, int64(7), stats1.TotalPeers)

stats2, err := store.CalculateUsageStats(context.TODO(), "account-2", startDate, endDate)
require.NoError(t, err)

assert.Equal(t, int64(1), stats2.ActiveUsers)
assert.Equal(t, int64(2), stats2.TotalUsers)
assert.Equal(t, int64(1), stats2.ActivePeers)
assert.Equal(t, int64(2), stats2.TotalPeers)
}
9 changes: 9 additions & 0 deletions management/server/http/accounts_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
Expand Down Expand Up @@ -39,6 +40,14 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts
accCopy.UpdateSettings(newSettings)
return accCopy, nil
},
GetCurrentUsageFunc: func(context.Context, string, string) (*server.AccountUsageStats, error) {

Check failure on line 43 in management/server/http/accounts_handler_test.go

View workflow job for this annotation

GitHub Actions / test (sqlite)

cannot use func(context.Context, string, string) (*server.AccountUsageStats, error) {…} (value of type func("context".Context, string, string) (*server.AccountUsageStats, error)) as func(ctx "context".Context, accountID string) (*server.AccountUsageStats, error) value in struct literal

Check failure on line 43 in management/server/http/accounts_handler_test.go

View workflow job for this annotation

GitHub Actions / lint (macos-latest)

cannot use func(context.Context, string, string) (*server.AccountUsageStats, error) {…} (value of type func("context".Context, string, string) (*server.AccountUsageStats, error)) as func(ctx "context".Context, accountID string) (*server.AccountUsageStats, error) value in struct literal (typecheck)

Check failure on line 43 in management/server/http/accounts_handler_test.go

View workflow job for this annotation

GitHub Actions / lint (ubuntu-latest)

cannot use func(context.Context, string, string) (*server.AccountUsageStats, error) {…} (value of type func("context".Context, string, string) (*server.AccountUsageStats, error)) as func(ctx "context".Context, accountID string) (*server.AccountUsageStats, error) value in struct literal (typecheck)

Check failure on line 43 in management/server/http/accounts_handler_test.go

View workflow job for this annotation

GitHub Actions / test (386, jsonfile)

cannot use func(context.Context, string, string) (*server.AccountUsageStats, error) {…} (value of type func("context".Context, string, string) (*server.AccountUsageStats, error)) as func(ctx "context".Context, accountID string) (*server.AccountUsageStats, error) value in struct literal

Check failure on line 43 in management/server/http/accounts_handler_test.go

View workflow job for this annotation

GitHub Actions / test (386, sqlite)

cannot use func(context.Context, string, string) (*server.AccountUsageStats, error) {…} (value of type func("context".Context, string, string) (*server.AccountUsageStats, error)) as func(ctx "context".Context, accountID string) (*server.AccountUsageStats, error) value in struct literal
return &server.AccountUsageStats{
ActiveUsers: 2,
TotalUsers: 3,
ActivePeers: 3,
TotalPeers: 6,
}, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
Expand Down
14 changes: 11 additions & 3 deletions management/server/http/handler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package http

import (
"context"
"fmt"
"net/http"

"github.com/gorilla/mux"
Expand All @@ -14,6 +16,8 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)

const apiPrefix = "/api"

// AuthCfg contains parameters for authentication middleware
type AuthCfg struct {
Issuer string
Expand All @@ -33,7 +37,7 @@ type emptyObject struct {
}

// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
func APIHandler(ctx context.Context, accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
Expand All @@ -59,7 +63,8 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()

router := rootRouter.PathPrefix("/api").Subrouter()
prefix := apiPrefix
router := rootRouter.PathPrefix(prefix).Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)

api := apiHandler{
Expand All @@ -68,7 +73,10 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
AuthCfg: authCfg,
}

integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor)
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}

api.addAccountsEndpoint()
api.addPeersEndpoint()
api.addUsersEndpoint()
Expand Down
7 changes: 6 additions & 1 deletion management/server/http/middleware/access_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"

Expand Down Expand Up @@ -36,9 +37,13 @@ func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessCont
var tokenPathRegexp = regexp.MustCompile(`^.*/api/users/.*/tokens.*$`)

// Handler method of the middleware which forbids all modify requests for non admin users
// It also adds
func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}

claims := a.claimsExtract.FromRequestContext(r)

user, err := a.getUser(claims)
Expand Down
6 changes: 6 additions & 0 deletions management/server/http/middleware/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
Expand Down Expand Up @@ -66,6 +67,11 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}

auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := strings.ToLower(auth[0])

Expand Down
Loading

0 comments on commit b8c81be

Please sign in to comment.