Skip to content

Commit

Permalink
f
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Nov 15, 2024
1 parent 75a3f80 commit 5590a2c
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 9 deletions.
10 changes: 10 additions & 0 deletions client/firewall/iptables/state_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,13 @@ func (s *ShutdownState) Cleanup() error {

return nil
}

func (s *ShutdownState) Clone() any {
return &ShutdownState{
InterfaceState: s.InterfaceState,
RouteRules: s.RouteRules,
RouteIPsetCounter: s.RouteIPsetCounter,
ACLEntries: s.ACLEntries,
ACLIPsetStore: s.ACLIPsetStore,
}
}
4 changes: 4 additions & 0 deletions client/internal/dns/unclean_shutdown_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ func (s *ShutdownState) Cleanup() error {

return nil
}

func (s *ShutdownState) Clone() any {
return s
}
31 changes: 29 additions & 2 deletions client/internal/routemanager/refcounter/refcounter.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,25 @@ import (

"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"

nberrors "github.com/netbirdio/netbird/client/errors"
)

// add a debug mutex
type DebugMutex struct {
mu sync.Mutex
}

func (m *DebugMutex) Lock() {
m.mu.Lock()
logCallerF("Locking")
}
func (m *DebugMutex) Unlock() {
logCallerF("Unlocking")
m.mu.Unlock()
}

const logLevel = log.TraceLevel

// ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key.
Expand Down Expand Up @@ -43,11 +58,11 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
//
// The types can be aliased to a specific type using the following syntax:
//
// type RouteRefCounter = Counter[netip.Prefix, any, any]
// type RouteRefCounter = Counter[netip.Prefix, struct{}, any]
type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O]
mu sync.Mutex
mu DebugMutex
// idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key
add AddFunc[Key, I, O]
Expand Down Expand Up @@ -244,6 +259,18 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error {
return nil
}

// Clone creates a deep copy of the Counter.
func (rm *Counter[Key, I, O]) Clone() *Counter[Key, I, O] {
rm.mu.Lock()
defer rm.mu.Unlock()
return &Counter[Key, I, O]{
refCountMap: maps.Clone(rm.refCountMap),
idMap: maps.Clone(rm.idMap),
add: rm.add,
remove: rm.remove,
}
}

func getCallerInfo(depth int, maxDepth int) (string, bool) {
if depth >= maxDepth {
return "", false
Expand Down
4 changes: 4 additions & 0 deletions client/internal/routemanager/systemops/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ func (s *ShutdownState) MarshalJSON() ([]byte, error) {
func (s *ShutdownState) UnmarshalJSON(data []byte) error {
return (*ExclusionCounter)(s).UnmarshalJSON(data)
}

func (s *ShutdownState) Clone() any {
return (*ShutdownState)((*ExclusionCounter)(s).Clone())
}
35 changes: 28 additions & 7 deletions client/internal/statemanager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,27 @@ import (
"io/fs"
"os"
"reflect"
"sync"
"time"

"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"

nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/util"
)

// State interface defines the methods that all state types must implement
type State interface {
Name() string
Cleanup() error
Clone() any
}

// Manager handles the persistence and management of various states
type Manager struct {
mu sync.Mutex
mu refcounter.DebugMutex
cancel context.CancelFunc
done chan struct{}

Expand Down Expand Up @@ -77,10 +78,15 @@ func (m *Manager) Stop(ctx context.Context) error {
m.mu.Unlock()

if cancel == nil {
log.Warn("state manager is not running")
return nil
}
cancel()

defer func() {
log.Debugf("ctx.Done: %v", ctx.Err())
}()

select {
case <-ctx.Done():
return ctx.Err()
Expand Down Expand Up @@ -152,7 +158,7 @@ func (m *Manager) setState(name string, state State) error {
}

func (m *Manager) periodicStateSave(ctx context.Context) {
ticker := time.NewTicker(10 * time.Second)
ticker := time.NewTicker(3 * time.Second)
defer ticker.Stop()
defer close(m.done)

Expand All @@ -175,23 +181,36 @@ func (m *Manager) PersistState(ctx context.Context) error {
}

m.mu.Lock()
defer m.mu.Unlock()

if len(m.dirty) == 0 {
m.mu.Unlock()

return nil
}

bs, err := marshalWithPanicRecovery(m.states)
if err != nil {
return fmt.Errorf("marshal states: %w", err)
var states = make(map[string]State, len(m.states))
for name, state := range m.states {
if state == nil {
continue
}
if s, ok := state.Clone().(State); ok {
states[name] = s
}
}
m.mu.Unlock()

ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

done := make(chan error, 1)
start := time.Now()
go func() {
bs, err := marshalWithPanicRecovery(states)
if err != nil {
done <- fmt.Errorf("marshal states: %w", err)
return
}

done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
}()

Expand All @@ -206,7 +225,9 @@ func (m *Manager) PersistState(ctx context.Context) error {

log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))

m.mu.Lock()
clear(m.dirty)
m.mu.Unlock()

return nil
}
Expand Down

0 comments on commit 5590a2c

Please sign in to comment.