From 7488c13a66502b32cab044756aaee11296c6a921 Mon Sep 17 00:00:00 2001 From: kinggo Date: Mon, 16 Jan 2023 23:52:16 +0800 Subject: [PATCH] reactor: modify storage inteaface and add contxt.Context --- app.go | 10 +++---- internal/memory/memory.go | 9 +++--- internal/storage/memory/memory.go | 13 ++++---- internal/storage/memory/memory_test.go | 41 +++++++++++++------------- middleware/cache/cache.go | 21 ++++++------- middleware/cache/manager.go | 31 +++++++++---------- middleware/csrf/csrf.go | 4 +-- middleware/csrf/manager.go | 31 +++++++++---------- middleware/idempotency/idempotency.go | 4 +-- middleware/limiter/limiter_fixed.go | 8 ++--- middleware/limiter/limiter_sliding.go | 4 +-- middleware/limiter/manager.go | 31 +++++++++---------- middleware/session/session.go | 13 ++++---- middleware/session/session_test.go | 37 ++++++++++++----------- middleware/session/store.go | 7 +++-- 15 files changed, 137 insertions(+), 127 deletions(-) diff --git a/app.go b/app.go index e299075b37e..560eba64df3 100644 --- a/app.go +++ b/app.go @@ -41,23 +41,23 @@ type Map map[string]interface{} type Storage interface { // Get gets the value for the given key. // `nil, nil` is returned when the key does not exist - Get(key string) ([]byte, error) + Get(ctx context.Context, key string) ([]byte, error) // Set stores the given value for the given key along // with an expiration value, 0 means no expiration. // Empty key or value will be ignored without an error. - Set(key string, val []byte, exp time.Duration) error + Set(ctx context.Context, key string, val []byte, exp time.Duration) error // Delete deletes the value for the given key. // It returns no error if the storage does not contain the key, - Delete(key string) error + Delete(ctx context.Context, key string) error // Reset resets the storage and delete all keys. - Reset() error + Reset(ctx context.Context) error // Close closes the storage and will stop any running garbage // collectors and open connections. - Close() error + Close(ctx context.Context) error } // ErrorHandler defines a function that will process all errors diff --git a/internal/memory/memory.go b/internal/memory/memory.go index d7b053de465..b742aed820b 100644 --- a/internal/memory/memory.go +++ b/internal/memory/memory.go @@ -3,6 +3,7 @@ package memory import ( + "context" "sync" "sync/atomic" "time" @@ -31,7 +32,7 @@ func New() *Storage { } // Get value by key -func (s *Storage) Get(key string) interface{} { +func (s *Storage) Get(_ context.Context, key string) interface{} { s.RLock() v, ok := s.data[key] s.RUnlock() @@ -42,7 +43,7 @@ func (s *Storage) Get(key string) interface{} { } // Set key with value -func (s *Storage) Set(key string, val interface{}, ttl time.Duration) { +func (s *Storage) Set(_ context.Context, key string, val interface{}, ttl time.Duration) { var exp uint32 if ttl > 0 { exp = uint32(ttl.Seconds()) + atomic.LoadUint32(&utils.Timestamp) @@ -54,14 +55,14 @@ func (s *Storage) Set(key string, val interface{}, ttl time.Duration) { } // Delete key by key -func (s *Storage) Delete(key string) { +func (s *Storage) Delete(_ context.Context, key string) { s.Lock() delete(s.data, key) s.Unlock() } // Reset all keys -func (s *Storage) Reset() { +func (s *Storage) Reset(_ context.Context) { nd := make(map[string]item) s.Lock() s.data = nd diff --git a/internal/storage/memory/memory.go b/internal/storage/memory/memory.go index 1a561061068..2183fa26287 100644 --- a/internal/storage/memory/memory.go +++ b/internal/storage/memory/memory.go @@ -3,6 +3,7 @@ package memory import ( + "context" "sync" "sync/atomic" "time" @@ -44,7 +45,7 @@ func New(config ...Config) *Storage { } // Get value by key -func (s *Storage) Get(key string) ([]byte, error) { +func (s *Storage) Get(_ context.Context, key string) ([]byte, error) { if len(key) <= 0 { return nil, nil } @@ -59,7 +60,7 @@ func (s *Storage) Get(key string) ([]byte, error) { } // Set key with value -func (s *Storage) Set(key string, val []byte, exp time.Duration) error { +func (s *Storage) Set(_ context.Context, key string, val []byte, exp time.Duration) error { // Ain't Nobody Got Time For That if len(key) <= 0 || len(val) <= 0 { return nil @@ -78,7 +79,7 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error { } // Delete key by key -func (s *Storage) Delete(key string) error { +func (s *Storage) Delete(_ context.Context, key string) error { // Ain't Nobody Got Time For That if len(key) <= 0 { return nil @@ -90,7 +91,7 @@ func (s *Storage) Delete(key string) error { } // Reset all keys -func (s *Storage) Reset() error { +func (s *Storage) Reset(_ context.Context) error { ndb := make(map[string]entry) s.mux.Lock() s.db = ndb @@ -99,7 +100,7 @@ func (s *Storage) Reset() error { } // Close the memory storage -func (s *Storage) Close() error { +func (s *Storage) Close(_ context.Context) error { s.done <- struct{}{} return nil } @@ -137,7 +138,7 @@ func (s *Storage) gc() { } } -// Return database client +// Conn Return database client func (s *Storage) Conn() map[string]entry { return s.db } diff --git a/internal/storage/memory/memory_test.go b/internal/storage/memory/memory_test.go index fb2b88a0e58..5b432e6e8ad 100644 --- a/internal/storage/memory/memory_test.go +++ b/internal/storage/memory/memory_test.go @@ -1,6 +1,7 @@ package memory import ( + "context" "testing" "time" @@ -16,7 +17,7 @@ func Test_Storage_Memory_Set(t *testing.T) { val = []byte("doe") ) - err := testStore.Set(key, val, 0) + err := testStore.Set(context.TODO(), key, val, 0) utils.AssertEqual(t, nil, err) } @@ -27,10 +28,10 @@ func Test_Storage_Memory_Set_Override(t *testing.T) { val = []byte("doe") ) - err := testStore.Set(key, val, 0) + err := testStore.Set(context.TODO(), key, val, 0) utils.AssertEqual(t, nil, err) - err = testStore.Set(key, val, 0) + err = testStore.Set(context.TODO(), key, val, 0) utils.AssertEqual(t, nil, err) } @@ -41,10 +42,10 @@ func Test_Storage_Memory_Get(t *testing.T) { val = []byte("doe") ) - err := testStore.Set(key, val, 0) + err := testStore.Set(context.TODO(), key, val, 0) utils.AssertEqual(t, nil, err) - result, err := testStore.Get(key) + result, err := testStore.Get(context.TODO(), key) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, val, result) } @@ -57,7 +58,7 @@ func Test_Storage_Memory_Set_Expiration(t *testing.T) { exp = 1 * time.Second ) - err := testStore.Set(key, val, exp) + err := testStore.Set(context.TODO(), key, val, exp) utils.AssertEqual(t, nil, err) time.Sleep(1100 * time.Millisecond) @@ -68,7 +69,7 @@ func Test_Storage_Memory_Get_Expired(t *testing.T) { key = "john" ) - result, err := testStore.Get(key) + result, err := testStore.Get(context.TODO(), key) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, true, len(result) == 0) } @@ -76,7 +77,7 @@ func Test_Storage_Memory_Get_Expired(t *testing.T) { func Test_Storage_Memory_Get_NotExist(t *testing.T) { t.Parallel() - result, err := testStore.Get("notexist") + result, err := testStore.Get(context.TODO(), "notexist") utils.AssertEqual(t, nil, err) utils.AssertEqual(t, true, len(result) == 0) } @@ -88,13 +89,13 @@ func Test_Storage_Memory_Delete(t *testing.T) { val = []byte("doe") ) - err := testStore.Set(key, val, 0) + err := testStore.Set(context.TODO(), key, val, 0) utils.AssertEqual(t, nil, err) - err = testStore.Delete(key) + err = testStore.Delete(context.TODO(), key) utils.AssertEqual(t, nil, err) - result, err := testStore.Get(key) + result, err := testStore.Get(context.TODO(), key) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, true, len(result) == 0) } @@ -105,27 +106,27 @@ func Test_Storage_Memory_Reset(t *testing.T) { val = []byte("doe") ) - err := testStore.Set("john1", val, 0) + err := testStore.Set(context.TODO(), "john1", val, 0) utils.AssertEqual(t, nil, err) - err = testStore.Set("john2", val, 0) + err = testStore.Set(context.TODO(), "john2", val, 0) utils.AssertEqual(t, nil, err) - err = testStore.Reset() + err = testStore.Reset(context.TODO()) utils.AssertEqual(t, nil, err) - result, err := testStore.Get("john1") + result, err := testStore.Get(context.TODO(), "john1") utils.AssertEqual(t, nil, err) utils.AssertEqual(t, true, len(result) == 0) - result, err = testStore.Get("john2") + result, err = testStore.Get(context.TODO(), "john2") utils.AssertEqual(t, nil, err) utils.AssertEqual(t, true, len(result) == 0) } func Test_Storage_Memory_Close(t *testing.T) { t.Parallel() - utils.AssertEqual(t, nil, testStore.Close()) + utils.AssertEqual(t, nil, testStore.Close(context.TODO())) } func Test_Storage_Memory_Conn(t *testing.T) { @@ -149,13 +150,13 @@ func Benchmark_Storage_Memory(b *testing.B) { b.ResetTimer() for n := 0; n < b.N; n++ { for _, key := range keys { - d.Set(key, value, ttl) + d.Set(context.TODO(), key, value, ttl) } for _, key := range keys { - _, _ = d.Get(key) + _, _ = d.Get(context.TODO(), key) } for _, key := range keys { - d.Delete(key) + d.Delete(context.TODO(), key) } } }) diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index f6db49e2cf1..4a34e9ce864 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -3,6 +3,7 @@ package cache import ( + "context" "strconv" "strings" "sync" @@ -80,11 +81,11 @@ func New(config ...Config) fiber.Handler { }() // Delete key from both manager and storage - deleteKey := func(dkey string) { - manager.delete(dkey) + deleteKey := func(ctx context.Context, dkey string) { + manager.delete(ctx, dkey) // External storage saves body data with different key if cfg.Storage != nil { - manager.delete(dkey + "_body") + manager.delete(ctx, dkey+"_body") } } @@ -113,7 +114,7 @@ func New(config ...Config) fiber.Handler { key := cfg.KeyGenerator(c) + "_" + c.Method() // Get entry from pool - e := manager.get(key) + e := manager.get(c.Context(), key) // Lock entry mux.Lock() @@ -123,7 +124,7 @@ func New(config ...Config) fiber.Handler { // Check if entry is expired if e.exp != 0 && ts >= e.exp { - deleteKey(key) + deleteKey(c.Context(), key) if cfg.MaxBytes > 0 { _, size := heap.remove(e.heapidx) storedBytes -= size @@ -132,7 +133,7 @@ func New(config ...Config) fiber.Handler { // Separate body value to avoid msgp serialization // We can store raw bytes with Storage 👍 if cfg.Storage != nil { - e.body = manager.getRaw(key + "_body") + e.body = manager.getRaw(c.Context(), key+"_body") } // Set response headers from cache c.Response().SetBodyRaw(e.body) @@ -189,7 +190,7 @@ func New(config ...Config) fiber.Handler { if cfg.MaxBytes > 0 { for storedBytes+bodySize > cfg.MaxBytes { key, size := heap.removeFirst() - deleteKey(key) + deleteKey(c.Context(), key) storedBytes -= size } } @@ -231,14 +232,14 @@ func New(config ...Config) fiber.Handler { // For external Storage we store raw body separated if cfg.Storage != nil { - manager.setRaw(key+"_body", e.body, expiration) + manager.setRaw(c.Context(), key+"_body", e.body, expiration) // avoid body msgp encoding e.body = nil - manager.set(key, e, expiration) + manager.set(c.Context(), key, e, expiration) manager.release(e) } else { // Store entry in memory - manager.set(key, e, expiration) + manager.set(c.Context(), key, e, expiration) } c.Set(cfg.CacheHeader, cacheMiss) diff --git a/middleware/cache/manager.go b/middleware/cache/manager.go index 6b9256fd232..f4835b5814b 100644 --- a/middleware/cache/manager.go +++ b/middleware/cache/manager.go @@ -1,6 +1,7 @@ package cache import ( + "context" "sync" "time" @@ -69,57 +70,57 @@ func (m *manager) release(e *item) { } // get data from storage or memory -func (m *manager) get(key string) (it *item) { +func (m *manager) get(ctx context.Context, key string) (it *item) { if m.storage != nil { it = m.acquire() - if raw, _ := m.storage.Get(key); raw != nil { + if raw, _ := m.storage.Get(ctx, key); raw != nil { if _, err := it.UnmarshalMsg(raw); err != nil { return } } return } - if it, _ = m.memory.Get(key).(*item); it == nil { + if it, _ = m.memory.Get(ctx, key).(*item); it == nil { it = m.acquire() } return } // get raw data from storage or memory -func (m *manager) getRaw(key string) (raw []byte) { +func (m *manager) getRaw(ctx context.Context, key string) (raw []byte) { if m.storage != nil { - raw, _ = m.storage.Get(key) + raw, _ = m.storage.Get(ctx, key) } else { - raw, _ = m.memory.Get(key).([]byte) + raw, _ = m.memory.Get(ctx, key).([]byte) } return } // set data to storage or memory -func (m *manager) set(key string, it *item, exp time.Duration) { +func (m *manager) set(ctx context.Context, key string, it *item, exp time.Duration) { if m.storage != nil { if raw, err := it.MarshalMsg(nil); err == nil { - _ = m.storage.Set(key, raw, exp) + _ = m.storage.Set(ctx, key, raw, exp) } } else { - m.memory.Set(key, it, exp) + m.memory.Set(ctx, key, it, exp) } } // set data to storage or memory -func (m *manager) setRaw(key string, raw []byte, exp time.Duration) { +func (m *manager) setRaw(ctx context.Context, key string, raw []byte, exp time.Duration) { if m.storage != nil { - _ = m.storage.Set(key, raw, exp) + _ = m.storage.Set(ctx, key, raw, exp) } else { - m.memory.Set(key, raw, exp) + m.memory.Set(ctx, key, raw, exp) } } // delete data from storage or memory -func (m *manager) delete(key string) { +func (m *manager) delete(ctx context.Context, key string) { if m.storage != nil { - _ = m.storage.Delete(key) + _ = m.storage.Delete(ctx, key) } else { - m.memory.Delete(key) + m.memory.Delete(ctx, key) } } diff --git a/middleware/csrf/csrf.go b/middleware/csrf/csrf.go index e7ad4f2a722..a3d782443d6 100644 --- a/middleware/csrf/csrf.go +++ b/middleware/csrf/csrf.go @@ -45,7 +45,7 @@ func New(config ...Config) fiber.Handler { } // if token does not exist in Storage - if manager.getRaw(token) == nil { + if manager.getRaw(c.Context(), token) == nil { // Expire cookie c.Cookie(&fiber.Cookie{ Name: cfg.CookieName, @@ -68,7 +68,7 @@ func New(config ...Config) fiber.Handler { } // Add/update token to Storage - manager.setRaw(token, dummyValue, cfg.Expiration) + manager.setRaw(c.Context(), token, dummyValue, cfg.Expiration) // Create cookie to pass token to client cookie := &fiber.Cookie{ diff --git a/middleware/csrf/manager.go b/middleware/csrf/manager.go index 13f0ccb657a..c6f8899b172 100644 --- a/middleware/csrf/manager.go +++ b/middleware/csrf/manager.go @@ -1,6 +1,7 @@ package csrf import ( + "context" "sync" "time" @@ -56,59 +57,59 @@ func (m *manager) release(e *item) { } // get data from storage or memory -func (m *manager) get(key string) (it *item) { +func (m *manager) get(ctx context.Context, key string) (it *item) { if m.storage != nil { it = m.acquire() - if raw, _ := m.storage.Get(key); raw != nil { + if raw, _ := m.storage.Get(ctx, key); raw != nil { if _, err := it.UnmarshalMsg(raw); err != nil { return } } return } - if it, _ = m.memory.Get(key).(*item); it == nil { + if it, _ = m.memory.Get(ctx, key).(*item); it == nil { it = m.acquire() } return } // get raw data from storage or memory -func (m *manager) getRaw(key string) (raw []byte) { +func (m *manager) getRaw(ctx context.Context, key string) (raw []byte) { if m.storage != nil { - raw, _ = m.storage.Get(key) + raw, _ = m.storage.Get(ctx, key) } else { - raw, _ = m.memory.Get(key).([]byte) + raw, _ = m.memory.Get(ctx, key).([]byte) } return } // set data to storage or memory -func (m *manager) set(key string, it *item, exp time.Duration) { +func (m *manager) set(ctx context.Context, key string, it *item, exp time.Duration) { if m.storage != nil { if raw, err := it.MarshalMsg(nil); err == nil { - _ = m.storage.Set(key, raw, exp) + _ = m.storage.Set(ctx, key, raw, exp) } } else { // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here - m.memory.Set(utils.CopyString(key), it, exp) + m.memory.Set(ctx, utils.CopyString(key), it, exp) } } // set data to storage or memory -func (m *manager) setRaw(key string, raw []byte, exp time.Duration) { +func (m *manager) setRaw(ctx context.Context, key string, raw []byte, exp time.Duration) { if m.storage != nil { - _ = m.storage.Set(key, raw, exp) + _ = m.storage.Set(ctx, key, raw, exp) } else { // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here - m.memory.Set(utils.CopyString(key), raw, exp) + m.memory.Set(ctx, utils.CopyString(key), raw, exp) } } // delete data from storage or memory -func (m *manager) delete(key string) { +func (m *manager) delete(ctx context.Context, key string) { if m.storage != nil { - _ = m.storage.Delete(key) + _ = m.storage.Delete(ctx, key) } else { - m.memory.Delete(key) + m.memory.Delete(ctx, key) } } diff --git a/middleware/idempotency/idempotency.go b/middleware/idempotency/idempotency.go index 3ad2b9dedce..08ceab2551f 100644 --- a/middleware/idempotency/idempotency.go +++ b/middleware/idempotency/idempotency.go @@ -35,7 +35,7 @@ func New(config ...Config) fiber.Handler { } maybeWriteCachedResponse := func(c *fiber.Ctx, key string) (bool, error) { - if val, err := cfg.Storage.Get(key); err != nil { + if val, err := cfg.Storage.Get(c.Context(), key); err != nil { return false, fmt.Errorf("failed to read response: %w", err) } else if val != nil { var res response @@ -138,7 +138,7 @@ func New(config ...Config) fiber.Handler { } // Store response - if err := cfg.Storage.Set(key, bs, cfg.Lifetime); err != nil { + if err := cfg.Storage.Set(c.Context(), key, bs, cfg.Lifetime); err != nil { return fmt.Errorf("failed to save response: %w", err) } diff --git a/middleware/limiter/limiter_fixed.go b/middleware/limiter/limiter_fixed.go index b6b6d35939e..c902a7d3b54 100644 --- a/middleware/limiter/limiter_fixed.go +++ b/middleware/limiter/limiter_fixed.go @@ -40,7 +40,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler { mux.Lock() // Get entry from pool and release when finished - e := manager.get(key) + e := manager.get(c.Context(), key) // Get timestamp ts := uint64(atomic.LoadUint32(&utils.Timestamp)) @@ -64,7 +64,7 @@ func (FixedWindow) New(cfg Config) fiber.Handler { remaining := cfg.Max - e.currHits // Update storage - manager.set(key, e, cfg.Expiration) + manager.set(c.Context(), key, e, cfg.Expiration) // Unlock entry mux.Unlock() @@ -88,10 +88,10 @@ func (FixedWindow) New(cfg Config) fiber.Handler { (cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) { // Lock entry mux.Lock() - e = manager.get(key) + e = manager.get(c.Context(), key) e.currHits-- remaining++ - manager.set(key, e, cfg.Expiration) + manager.set(c.Context(), key, e, cfg.Expiration) // Unlock entry mux.Unlock() } diff --git a/middleware/limiter/limiter_sliding.go b/middleware/limiter/limiter_sliding.go index 7f49863d7a1..eb2aa7b9db7 100644 --- a/middleware/limiter/limiter_sliding.go +++ b/middleware/limiter/limiter_sliding.go @@ -41,7 +41,7 @@ func (SlidingWindow) New(cfg Config) fiber.Handler { mux.Lock() // Get entry from pool and release when finished - e := manager.get(key) + e := manager.get(c.Context(), key) // Get timestamp ts := uint64(atomic.LoadUint32(&utils.Timestamp)) @@ -95,7 +95,7 @@ func (SlidingWindow) New(cfg Config) fiber.Handler { // we add the expiration to the duration. // Otherwise after the end of "sample window", attackers could launch // a new request with the full window length. - manager.set(key, e, time.Duration(resetInSec+expiration)*time.Second) + manager.set(c.Context(), key, e, time.Duration(resetInSec+expiration)*time.Second) // Unlock entry mux.Unlock() diff --git a/middleware/limiter/manager.go b/middleware/limiter/manager.go index 68a785a7c27..33e0df992be 100644 --- a/middleware/limiter/manager.go +++ b/middleware/limiter/manager.go @@ -1,6 +1,7 @@ package limiter import ( + "context" "sync" "time" @@ -58,59 +59,59 @@ func (m *manager) release(e *item) { } // get data from storage or memory -func (m *manager) get(key string) (it *item) { +func (m *manager) get(ctx context.Context, key string) (it *item) { if m.storage != nil { it = m.acquire() - if raw, _ := m.storage.Get(key); raw != nil { + if raw, _ := m.storage.Get(ctx, key); raw != nil { if _, err := it.UnmarshalMsg(raw); err != nil { return } } return } - if it, _ = m.memory.Get(key).(*item); it == nil { + if it, _ = m.memory.Get(ctx, key).(*item); it == nil { it = m.acquire() } return } // get raw data from storage or memory -func (m *manager) getRaw(key string) (raw []byte) { +func (m *manager) getRaw(ctx context.Context, key string) (raw []byte) { if m.storage != nil { - raw, _ = m.storage.Get(key) + raw, _ = m.storage.Get(ctx, key) } else { - raw, _ = m.memory.Get(key).([]byte) + raw, _ = m.memory.Get(ctx, key).([]byte) } return } // set data to storage or memory -func (m *manager) set(key string, it *item, exp time.Duration) { +func (m *manager) set(ctx context.Context, key string, it *item, exp time.Duration) { if m.storage != nil { if raw, err := it.MarshalMsg(nil); err == nil { - _ = m.storage.Set(key, raw, exp) + _ = m.storage.Set(ctx, key, raw, exp) } // we can release data because it's serialized to database m.release(it) } else { - m.memory.Set(key, it, exp) + m.memory.Set(ctx, key, it, exp) } } // set data to storage or memory -func (m *manager) setRaw(key string, raw []byte, exp time.Duration) { +func (m *manager) setRaw(ctx context.Context, key string, raw []byte, exp time.Duration) { if m.storage != nil { - _ = m.storage.Set(key, raw, exp) + _ = m.storage.Set(ctx, key, raw, exp) } else { - m.memory.Set(key, raw, exp) + m.memory.Set(ctx, key, raw, exp) } } // delete data from storage or memory -func (m *manager) delete(key string) { +func (m *manager) delete(ctx context.Context, key string) { if m.storage != nil { - _ = m.storage.Delete(key) + _ = m.storage.Delete(ctx, key) } else { - m.memory.Delete(key) + m.memory.Delete(ctx, key) } } diff --git a/middleware/session/session.go b/middleware/session/session.go index 33be2a1a7d9..263fcd8f5f0 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -2,6 +2,7 @@ package session import ( "bytes" + "context" "encoding/gob" "sync" "time" @@ -91,7 +92,7 @@ func (s *Session) Delete(key string) { } // Destroy will delete the session from Storage and expire session cookie -func (s *Session) Destroy() error { +func (s *Session) Destroy(ctx context.Context) error { // Better safe than sorry if s.data == nil { return nil @@ -101,7 +102,7 @@ func (s *Session) Destroy() error { s.data.Reset() // Use external Storage if exist - if err := s.config.Storage.Delete(s.id); err != nil { + if err := s.config.Storage.Delete(ctx, s.id); err != nil { return err } @@ -111,9 +112,9 @@ func (s *Session) Destroy() error { } // Regenerate generates a new session id and delete the old one from Storage -func (s *Session) Regenerate() error { +func (s *Session) Regenerate(ctx context.Context) error { // Delete old id from storage - if err := s.config.Storage.Delete(s.id); err != nil { + if err := s.config.Storage.Delete(ctx, s.id); err != nil { return err } @@ -133,7 +134,7 @@ func (s *Session) refresh() { } // Save will update the storage and client cookie -func (s *Session) Save() error { +func (s *Session) Save(ctx context.Context) error { // Better safe than sorry if s.data == nil { return nil @@ -161,7 +162,7 @@ func (s *Session) Save() error { copy(encodedBytes, s.byteBuffer.Bytes()) // pass copied bytes with session id to provider - if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil { + if err := s.config.Storage.Set(ctx, s.id, encodedBytes, s.exp); err != nil { return err } diff --git a/middleware/session/session_test.go b/middleware/session/session_test.go index 489ea46edff..11c7968449e 100644 --- a/middleware/session/session_test.go +++ b/middleware/session/session_test.go @@ -1,6 +1,7 @@ package session import ( + "context" "testing" "time" @@ -66,7 +67,7 @@ func Test_Session(t *testing.T) { utils.AssertEqual(t, "123", id) // save the old session first - err = sess.Save() + err = sess.Save(context.TODO()) utils.AssertEqual(t, nil, err) // requesting entirely new context to prevent falsy tests @@ -169,7 +170,7 @@ func Test_Session_Types(t *testing.T) { sess.Set("vcomplex128", vcomplex128) // save session - err = sess.Save() + err = sess.Save(context.TODO()) utils.AssertEqual(t, nil, err) // get session @@ -218,10 +219,10 @@ func Test_Session_Store_Reset(t *testing.T) { // set value & save sess.Set("hello", "world") ctx.Request().Header.SetCookie(store.sessionName, sess.ID()) - utils.AssertEqual(t, nil, sess.Save()) + utils.AssertEqual(t, nil, sess.Save(context.TODO())) // reset store - utils.AssertEqual(t, nil, store.Reset()) + utils.AssertEqual(t, nil, store.Reset(context.TODO())) // make sure the session is recreated sess, _ = store.Get(ctx) @@ -247,7 +248,7 @@ func Test_Session_Save(t *testing.T) { sess.Set("name", "john") // save session - err := sess.Save() + err := sess.Save(context.TODO()) utils.AssertEqual(t, nil, err) }) @@ -267,7 +268,7 @@ func Test_Session_Save(t *testing.T) { sess.Set("name", "john") // save session - err := sess.Save() + err := sess.Save(context.TODO()) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Response().Header.Peek(store.sessionName))) utils.AssertEqual(t, store.getSessionID(ctx), string(ctx.Request().Header.Peek(store.sessionName))) @@ -295,7 +296,7 @@ func Test_Session_Save_Expiration(t *testing.T) { sess.SetExpiry(time.Second * 5) // save session - err := sess.Save() + err := sess.Save(context.TODO()) utils.AssertEqual(t, nil, err) // here you need to get the old session yet @@ -328,7 +329,7 @@ func Test_Session_Reset(t *testing.T) { sess, _ := store.Get(ctx) sess.Set("name", "fenny") - utils.AssertEqual(t, nil, sess.Destroy()) + utils.AssertEqual(t, nil, sess.Destroy(context.TODO())) name := sess.Get("name") utils.AssertEqual(t, nil, name) }) @@ -349,10 +350,10 @@ func Test_Session_Reset(t *testing.T) { // set value & save sess.Set("name", "fenny") - utils.AssertEqual(t, nil, sess.Save()) + utils.AssertEqual(t, nil, sess.Save(context.TODO())) sess, _ = store.Get(ctx) - err := sess.Destroy() + err := sess.Destroy(context.TODO()) utils.AssertEqual(t, nil, err) utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName))) utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName))) @@ -384,7 +385,7 @@ func Test_Session_Cookie(t *testing.T) { // get session sess, _ := store.Get(ctx) - utils.AssertEqual(t, nil, sess.Save()) + utils.AssertEqual(t, nil, sess.Save(context.TODO())) // cookie should be set on Save ( even if empty data ) utils.AssertEqual(t, 84, len(ctx.Response().Header.PeekCookie(store.sessionName))) @@ -404,7 +405,7 @@ func Test_Session_Cookie_In_Response(t *testing.T) { sess, _ := store.Get(ctx) sess.Set("id", "1") utils.AssertEqual(t, true, sess.Fresh()) - utils.AssertEqual(t, nil, sess.Save()) + utils.AssertEqual(t, nil, sess.Save(context.TODO())) sess, _ = store.Get(ctx) sess.Set("name", "john") @@ -429,12 +430,12 @@ func Test_Session_Deletes_Single_Key(t *testing.T) { ctx.Request().Header.SetCookie(store.sessionName, sess.ID()) sess.Set("id", "1") - utils.AssertEqual(t, nil, sess.Save()) + utils.AssertEqual(t, nil, sess.Save(context.TODO())) sess, err = store.Get(ctx) utils.AssertEqual(t, nil, err) sess.Delete("id") - utils.AssertEqual(t, nil, sess.Save()) + utils.AssertEqual(t, nil, sess.Save(context.TODO())) sess, err = store.Get(ctx) utils.AssertEqual(t, nil, err) @@ -463,7 +464,7 @@ func Test_Session_Regenerate(t *testing.T) { originalSessionUUIDString = freshSession.ID() - err = freshSession.Save() + err = freshSession.Save(context.TODO()) utils.AssertEqual(t, nil, err) // set cookie @@ -474,7 +475,7 @@ func Test_Session_Regenerate(t *testing.T) { utils.AssertEqual(t, nil, err) utils.AssertEqual(t, false, acquiredSession.Fresh()) - err = acquiredSession.Regenerate() + err = acquiredSession.Regenerate(context.TODO()) utils.AssertEqual(t, nil, err) if acquiredSession.ID() == originalSessionUUIDString { @@ -499,7 +500,7 @@ func Benchmark_Session(b *testing.B) { for n := 0; n < b.N; n++ { sess, _ := store.Get(c) sess.Set("john", "doe") - err = sess.Save() + err = sess.Save(context.TODO()) } utils.AssertEqual(b, nil, err) @@ -514,7 +515,7 @@ func Benchmark_Session(b *testing.B) { for n := 0; n < b.N; n++ { sess, _ := store.Get(c) sess.Set("john", "doe") - err = sess.Save() + err = sess.Save(context.TODO()) } utils.AssertEqual(b, nil, err) diff --git a/middleware/session/store.go b/middleware/session/store.go index cc8a80a0cf5..d5dd8deebfa 100644 --- a/middleware/session/store.go +++ b/middleware/session/store.go @@ -1,6 +1,7 @@ package session import ( + "context" "encoding/gob" "sync" @@ -65,7 +66,7 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) { // Fetch existing data if loadData { - raw, err := s.Storage.Get(id) + raw, err := s.Storage.Get(c.Context(), id) // Unmarshal if we found data if raw != nil && err == nil { mux.Lock() @@ -135,6 +136,6 @@ func (s *Store) responseCookies(c *fiber.Ctx) (string, error) { } // Reset will delete all session from the storage -func (s *Store) Reset() error { - return s.Storage.Reset() +func (s *Store) Reset(ctx context.Context) error { + return s.Storage.Reset(ctx) }