Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add account usage logic #1567

Merged
merged 15 commits into from
Feb 22, 2024
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ require (
github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -376,10 +376,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22 h1:XTiNnVB6OEwung8WIiGJNzOTLVefuSzAA/cu+6Sst8A=
github.com/netbirdio/management-integrations/additions v0.0.0-20240118163419-8a7c87accb22/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22 h1:FNc4p8RS/gFm5jlmvUFWC4/5YxPDWejYyqEBVziFZwo=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240118163419-8a7c87accb22/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM=
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 h1:yzcQKizAK9YufCHMMCIsr467Dw/OU/4xyHbWizGb1E4=
github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 h1:OFlzVZtkXCoJsfDKrMigFpuad8ZXTm8epq6x27K0irA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=
Expand Down
7 changes: 4 additions & 3 deletions management/cmd/management.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,10 @@ var (
UserIDClaim: config.HttpConfig.AuthUserIDClaim,
KeysLocation: config.HttpConfig.AuthKeysLocation,
}
httpAPIHandler, err := httpapi.APIHandler(accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)

ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg)
if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err)
}
Expand All @@ -264,8 +267,6 @@ var (
}

if !disableMetrics {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
idpManager := "disabled"
if config.IdpManagerConfig != nil && config.IdpManagerConfig.ManagerType != "" {
idpManager = config.IdpManagerConfig.ManagerType
Expand Down
25 changes: 23 additions & 2 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ type AccountManager interface {
CheckUserAccessByJWTGroups(claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(accountID, userID string) error
GetUsage(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error)
MarkPATUsed(tokenID string) error
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(accountID string) ([]*User, error)
Expand Down Expand Up @@ -110,7 +111,7 @@ type AccountManager interface {
DeleteNameServerGroup(accountID, nsGroupID, userID string) error
ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error)
GetDNSDomain() string
StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any)
StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any)
GetEvents(accountID, userID string) ([]*activity.Event, error)
GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
Expand Down Expand Up @@ -230,6 +231,14 @@ type Account struct {
RulesG []Rule `json:"-" gorm:"-"`
}

// AccountUsageStats represents the current usage statistics for an account
type AccountUsageStats struct {
ActiveUsers int64 `json:"active_users"`
TotalUsers int64 `json:"total_users"`
surik marked this conversation as resolved.
Show resolved Hide resolved
ActivePeers int64 `json:"active_peers"`
TotalPeers int64 `json:"total_peers"`
}

type UserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Expand Down Expand Up @@ -1105,8 +1114,20 @@ func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error {
return nil
}

// GetUsage returns the usage stats for the given account.
// This cannot be used to calculate usage stats for a period in the past as it relies on peers' last seen time.
func (am *DefaultAccountManager) GetUsage(ctx context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) {
usageStats, err := am.Store.CalculateUsageStats(ctx, accountID, start, end)
if err != nil {
return nil, fmt.Errorf("failed to calculate usage stats: %w", err)
}

return usageStats, nil
}

// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
// userID doesn't have an account associated with it, one account is created
// domain is used to create a new account if no account is found
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) {
if accountID != "" {
return am.Store.GetAccount(accountID)
Expand Down Expand Up @@ -1791,7 +1812,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut
return nil
}

// addAllGroup to account object if it doesn't exists
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
allGroup := &Group{
Expand Down
15 changes: 11 additions & 4 deletions management/server/activity/codes.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package activity

import "maps"

// Activity that triggered an Event
type Activity int

// Code is an activity string representation
type Code struct {
message string
code string
Message string
Code string
}

const (
Expand Down Expand Up @@ -207,15 +209,20 @@ var activityMap = map[Activity]Code{
// StringCode returns a string code of the activity
func (a Activity) StringCode() string {
if code, ok := activityMap[a]; ok {
return code.code
return code.Code
}
return "UNKNOWN_ACTIVITY"
}

// Message returns a string representation of an activity
func (a Activity) Message() string {
if code, ok := activityMap[a]; ok {
return code.message
return code.Message
}
return "UNKNOWN_ACTIVITY"
}

// RegisterActivityMap adds new codes to the activity map
func RegisterActivityMap(codes map[Activity]Code) {
maps.Copy(activityMap, codes)
}
8 changes: 7 additions & 1 deletion management/server/activity/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@ const (
SystemInitiator = "sys"
)

// ActivityDescriber is an interface that describes an activity
type ActivityDescriber interface {
StringCode() string
Message() string
}

// Event represents a network/system activity event.
type Event struct {
// Timestamp of the event
Timestamp time.Time
// Activity that was performed during the event
Activity Activity
Activity ActivityDescriber
// ID of the event (can be empty, meaning that it wasn't yet generated)
ID uint64
// InitiatorID is the ID of an object that initiated the event (e.g., a user)
Expand Down
3 changes: 1 addition & 2 deletions management/server/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit
return filtered, nil
}

func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity,
meta map[string]any) {
func (am *DefaultAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {

go func() {
_, err := am.eventStore.Save(&activity.Event{
Expand Down
39 changes: 39 additions & 0 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package server

import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -662,3 +664,40 @@ func (s *FileStore) Close() error {
func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}

// CalculateUsageStats returns the usage stats for an account
// start and end are inclusive.
func (s *FileStore) CalculateUsageStats(_ context.Context, accountID string, start time.Time, end time.Time) (*AccountUsageStats, error) {
s.mux.Lock()
defer s.mux.Unlock()

account, exists := s.Accounts[accountID]
lixmal marked this conversation as resolved.
Show resolved Hide resolved
if !exists {
return nil, fmt.Errorf("account not found")
}

stats := &AccountUsageStats{
TotalUsers: 0,
TotalPeers: int64(len(account.Peers)),
}

for _, user := range account.Users {
if !user.IsServiceUser {
stats.TotalUsers++
}
}

activeUsers := make(map[string]bool)
for _, peer := range account.Peers {
lastSeen := peer.Status.LastSeen
if lastSeen.Compare(start) >= 0 && lastSeen.Compare(end) <= 0 {
if _, exists := account.Users[peer.UserID]; exists && !activeUsers[peer.UserID] {
activeUsers[peer.UserID] = true
stats.ActiveUsers++
}
stats.ActivePeers++
}
}

return stats, nil
}
30 changes: 30 additions & 0 deletions management/server/file_store_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"crypto/sha256"
"net"
"path/filepath"
Expand Down Expand Up @@ -657,3 +658,32 @@ func newStore(t *testing.T) *FileStore {

return store
}

func TestFileStore_CalculateUsageStats(t *testing.T) {
storeDir := t.TempDir()

err := util.CopyFileContents("testdata/store_stats.json", filepath.Join(storeDir, "store.json"))
require.NoError(t, err)

store, err := NewFileStore(storeDir, nil)
require.NoError(t, err)

startDate := time.Date(2024, time.February, 1, 0, 0, 0, 0, time.UTC)
endDate := startDate.AddDate(0, 1, 0).Add(-time.Nanosecond)

stats1, err := store.CalculateUsageStats(context.TODO(), "account-1", startDate, endDate)
require.NoError(t, err)

assert.Equal(t, int64(2), stats1.ActiveUsers)
assert.Equal(t, int64(4), stats1.TotalUsers)
assert.Equal(t, int64(3), stats1.ActivePeers)
assert.Equal(t, int64(7), stats1.TotalPeers)

stats2, err := store.CalculateUsageStats(context.TODO(), "account-2", startDate, endDate)
require.NoError(t, err)

assert.Equal(t, int64(1), stats2.ActiveUsers)
assert.Equal(t, int64(2), stats2.TotalUsers)
assert.Equal(t, int64(1), stats2.ActivePeers)
assert.Equal(t, int64(2), stats2.TotalPeers)
}
14 changes: 11 additions & 3 deletions management/server/http/handler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package http

import (
"context"
"fmt"
"net/http"

"github.com/gorilla/mux"
Expand All @@ -15,6 +17,8 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)

const apiPrefix = "/api"

// AuthCfg contains parameters for authentication middleware
type AuthCfg struct {
Issuer string
Expand All @@ -35,7 +39,7 @@ type emptyObject struct {
}

// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
Expand All @@ -61,7 +65,8 @@ func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Ge
rootRouter := mux.NewRouter()
metricsMiddleware := appMetrics.HTTPMiddleware()

router := rootRouter.PathPrefix("/api").Subrouter()
prefix := apiPrefix
router := rootRouter.PathPrefix(prefix).Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)

api := apiHandler{
Expand All @@ -71,7 +76,10 @@ func APIHandler(accountManager s.AccountManager, LocationManager *geolocation.Ge
AuthCfg: authCfg,
}

integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor)
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}

api.addAccountsEndpoint()
api.addPeersEndpoint()
api.addUsersEndpoint()
Expand Down
7 changes: 6 additions & 1 deletion management/server/http/middleware/access_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"

Expand Down Expand Up @@ -36,9 +37,13 @@ func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessCont
var tokenPathRegexp = regexp.MustCompile(`^.*/api/users/.*/tokens.*$`)

// Handler method of the middleware which forbids all modify requests for non admin users
// It also adds
func (a *AccessControl) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}

claims := a.claimsExtract.FromRequestContext(r)

user, err := a.getUser(claims)
Expand Down
6 changes: 6 additions & 0 deletions management/server/http/middleware/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
Expand Down Expand Up @@ -66,6 +67,11 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}

auth := strings.Split(r.Header.Get("Authorization"), " ")
authType := strings.ToLower(auth[0])

Expand Down
Loading
Loading