Skip to content

Commit

Permalink
Add account usage logic (#1567)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Yury Gargay <yury.gargay@gmail.com>
  • Loading branch information
lixmal and surik authored Feb 22, 2024
1 parent e18bf56 commit b7a6cbf
Show file tree
Hide file tree
Showing 21 changed files with 576 additions and 30 deletions.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ require (
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -376,10 +376,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22 h1:XTiNnVB6OEwung8WIiGJNzOTLVefuSzAA/cu+6Sst8A=
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22 h1:FNc4p8RS/gFm5jlmvUFWC4/5YxPDWejYyqEBVziFZwo=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM=
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 h1:yzcQKizAK9YufCHMMCIsr467Dw/OU/4xyHbWizGb1E4=
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 h1:OFlzVZtkXCoJsfDKrMigFpuad8ZXTm8epq6x27K0irA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
Expand Down
7 changes: 4 additions & 3 deletions management/cmd/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,10 @@ var (
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
httpAPIHandler, err := httpapi.APIHandler(accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)

ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
Expand All @@ -264,8 +267,6 @@ var (
}

if !disableMetrics {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
idpManager := "disabled"
if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" {
idpManager = config.IdpManagerConfig.ManagerType
Expand Down
25 changes: 23 additions & 2 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type AccountManager interface {
CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(accountID, userID string) error
GetUsage(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error)
MarkPATUsed(tokenID string) error
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(accountID string) ([]*User, error)
Expand Down Expand Up @@ -110,7 +111,7 @@ type AccountManager interface {
DeleteNameServerGroup(accountID, nsGroupID, userID string) error
ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain() string
StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any)
StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEvents(accountID, userID string) ([]*activity.Event, error)
GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
Expand Down Expand Up @@ -230,6 +231,14 @@ type Account struct {
RulesG []Rule `json:"-" gorm:"-"`
}

// AccountUsageStats represents the current usage statistics for an account
type AccountUsageStats struct {
ActiveUsers int64 `json:"active_users"`
TotalUsers int64 `json:"total_users"`
ActivePeers int64 `json:"active_peers"`
TotalPeers int64 `json:"total_peers"`
}

type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Expand Down Expand Up @@ -1105,8 +1114,20 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
return nil
}

// GetUsage returns the usage stats for the given account.
// This cannot be used to calculate usage stats for a period in the past as it relies on peers' last seen time.
func (am *DefaultAccountManager) GetUsage(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) {
usageStats, err := am.Store.CalculateUsageStats(ctx, accountID, start, end)
if err != nil {
return nil, fmt.Errorf("failed to calculate usage stats: %w", err)
}

return usageStats, nil
}

// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
// userID doesn't have an account associated with it, one account is created
// domain is used to create a new account if no account is found
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
if accountID != "" {
return am.Store.GetAccount(accountID)
Expand Down Expand Up @@ -1791,7 +1812,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
return nil
}

// addAllGroup to account object if it doesn't exists
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
allGroup := &Group{
Expand Down
15 changes: 11 additions & 4 deletions management/server/activity/codes.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package activity

import "maps"

// Activity that triggered an Event
type Activity int

// Code is an activity string representation
type Code struct {
message string
code string
Message string
Code string
}

const (
Expand Down Expand Up @@ -207,15 +209,20 @@ var activityMap = map[Activity]Code{
// StringCode returns a string code of the activity
func (a Activity) StringCode() string {
if code, ok := activityMap[a]; ok {
return code.code
return code.Code
}
return "UNKNOWN_ACTIVITY"
}

// Message returns a string representation of an activity
func (a Activity) Message() string {
if code, ok := activityMap[a]; ok {
return code.message
return code.Message
}
return "UNKNOWN_ACTIVITY"
}

// RegisterActivityMap adds new codes to the activity map
func RegisterActivityMap(codes map[Activity]Code) {
maps.Copy(activityMap, codes)
}
8 changes: 7 additions & 1 deletion management/server/activity/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@ const (
SystemInitiator = "sys"
)

// ActivityDescriber is an interface that describes an activity
type ActivityDescriber interface {
StringCode() string
Message() string
}

// Event represents a network/system activity event.
type Event struct {
// Timestamp of the event
Timestamp time.Time
// Activity that was performed during the event
Activity Activity
Activity ActivityDescriber
// ID of the event (can be empty, meaning that it wasn't yet generated)
ID uint64
// InitiatorID is the ID of an object that initiated the event (e.g., a user)
Expand Down
3 changes: 1 addition & 2 deletions management/server/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
return filtered, nil
}

func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity,
meta map[string]any) {
func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {

go func() {
_, err := am.eventStore.Save(&activity.Event{
Expand Down
39 changes: 39 additions & 0 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package server

import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -662,3 +664,40 @@ func (s *FileStore) Close() error {
func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}

// CalculateUsageStats returns the usage stats for an account
// start and end are inclusive.
func (s *FileStore) CalculateUsageStats(_ context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) {
s.mux.Lock()
defer s.mux.Unlock()

account, exists := s.Accounts[accountID]
if !exists {
return nil, fmt.Errorf("account not found")
}

stats := &AccountUsageStats{
TotalUsers: 0,
TotalPeers: int64(len(account.Peers)),
}

for _, user := range account.Users {
if !user.IsServiceUser {
stats.TotalUsers++
}
}

activeUsers := make(map[string]bool)
for _, peer := range account.Peers {
lastSeen := peer.Status.LastSeen
if lastSeen.Compare(start) >= 0 && lastSeen.Compare(end) <= 0 {
if _, exists := account.Users[peer.UserID]; exists && !activeUsers[peer.UserID] {
activeUsers[peer.UserID] = true
stats.ActiveUsers++
}
stats.ActivePeers++
}
}

return stats, nil
}
30 changes: 30 additions & 0 deletions management/server/file_store_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"crypto/sha256"
"net"
"path/filepath"
Expand Down Expand Up @@ -657,3 +658,32 @@ func newStore(t *testing.T) *FileStore {

return store
}

func TestFileStore_CalculateUsageStats(t *testing.T) {
storeDir := t.TempDir()

err := util.CopyFileContents("testdata/store_stats.json", filepath.Join(storeDir, "store.json"))
require.NoError(t, err)

store, err := NewFileStore(storeDir, nil)
require.NoError(t, err)

startDate := time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC)
endDate := startDate.AddDate(0, 1, 0).Add(-time.Nanosecond)

stats1, err := store.CalculateUsageStats(context.TODO(), "account-1", startDate, endDate)
require.NoError(t, err)

assert.Equal(t, int64(2), stats1.ActiveUsers)
assert.Equal(t, int64(4), stats1.TotalUsers)
assert.Equal(t, int64(3), stats1.ActivePeers)
assert.Equal(t, int64(7), stats1.TotalPeers)

stats2, err := store.CalculateUsageStats(context.TODO(), "account-2", startDate, endDate)
require.NoError(t, err)

assert.Equal(t, int64(1), stats2.ActiveUsers)
assert.Equal(t, int64(2), stats2.TotalUsers)
assert.Equal(t, int64(1), stats2.ActivePeers)
assert.Equal(t, int64(2), stats2.TotalPeers)
}
14 changes: 11 additions & 3 deletions management/server/http/handler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package http

import (
"context"
"fmt"
"net/http"

"github.com/gorilla/mux"
Expand All @@ -15,6 +17,8 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)

const apiPrefix = "/api"

// AuthCfg contains parameters for authentication middleware
type AuthCfg struct {
Issuer string
Expand All @@ -35,7 +39,7 @@ type emptyObject struct {
}

// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
Expand All @@ -61,7 +65,8 @@ func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Ge
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()

router := rootRouter.PathPrefix("/api").Subrouter()
prefix := apiPrefix
router := rootRouter.PathPrefix(prefix).Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)

api := apiHandler{
Expand All @@ -71,7 +76,10 @@ func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Ge
AuthCfg: authCfg,
}

integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor)
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}

api.addAccountsEndpoint()
api.addPeersEndpoint()
api.addUsersEndpoint()
Expand Down
7 changes: 6 additions & 1 deletion management/server/http/middleware/access_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"

Expand Down Expand Up @@ -36,9 +37,13 @@ func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessCont
var tokenPathRegexp = regexp.MustCompile(`^.*/api/users/.*/tokens.*$`)

// Handler method of the middleware which forbids all modify requests for non admin users
// It also adds
func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}

claims := a.claimsExtract.FromRequestContext(r)

user, err := a.getUser(claims)
Expand Down
6 changes: 6 additions & 0 deletions management/server/http/middleware/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
Expand Down Expand Up @@ -66,6 +67,11 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}

auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := strings.ToLower(auth[0])

Expand Down
Loading

0 comments on commit b7a6cbf

Please sign in to comment.