diff --git a/README.md b/README.md index eeb3476..ae64654 100644 --- a/README.md +++ b/README.md @@ -9,39 +9,41 @@ Most session libraries are highly opinionated and hard-wired to work with `net/h ## Features 1. Framework/network library agnostic. 2. Simple API and with support for primitive data types. Complex types can be stored using own encoding/decoding. -3. Pre-built redis/postgres/in-memory stores that can be separately installed. +3. Pre-built redis/postgres/in-memory/securecookie stores that can be separately installed. 4. Multiple session instances with custom handlers and different backend stores. ## Installation Install `simplesessions` and all [available stores](/stores). ```shell -go get -u github.com/vividvilla/simplesessions +go get -u github.com/vividvilla/simplesessions/v3 -# Install the requrired store: memory|goredis|redis|postgres -go get -u github.com/vividvilla/simplesessions/stores/goredis +# Install the requrired store: memory|redis|postgres|securecookie +go get -u github.com/vividvilla/simplesessions/v3/stores/redis +go get -u github.com/vividvilla/simplesessions/v3/stores/postgres ``` # Stores Sessions can be stored to any backend by implementing the [store](/store.go) interface. The following stores are bundled. -* [in-memory](/stores/memory) * [redis](/stores/redis) +* [postgres](/stores/postgres) +* [in-memory](/stores/memory) * [secure cookie](/stores/securecookie) # Usage Check the [examples](/examples) directory for complete examples. ## Connecting a store -Stores can be registered to a session instance by using `Use` method. +Stores can be registered to a session instance by using `Use` method. Check individual [Stores](#stores) docs for more details. ```go sess := simplesessions.New(simplesessions.Options{}) -sess.UseStore(memory.New()) +sess.UseStore(store.New()) ``` ## Connecting an HTTP handler -Any HTTP library can be connected to simplesessions by registering the `RegisterGetCookie()` and `RegisterSetCookie()` callbacks. The below example shows a simple `net/http` usecase. Another example showing `fasthttp` can be found [here](/examples). +Any HTTP library can be connected to simplesessions by registering the get and set cookie hooks using `SetCookieHooks()`. The below example shows a simple `net/http` usecase. Another example showing `fasthttp` can be found [here](/examples). ```go var sessMan *simplesessions.Manager @@ -81,44 +83,58 @@ func setCookie(cookie *http.Cookie, w interface{}) error { func handler(w http.ResponseWriter, r *http.Request) { // Use method `Acquire` to acquire a session before you access the session. // Acquire takes read, write interface and context respectively. - // Read interface sent to callback registered with `RegisterGetCookie` - // and write interface is sent to callback registered with `RegisterWriteCookie` + // Read interface sent to callback registered with get cookie hook + // and write interface is sent to callback registered with write cookie hook + // set using `SetCookieHooks()` method. + // // Optionally `context` can be sent which is usually request context where acquire // session will get previously loaded session. This is useful if you have multiple // middlewares accessing sessions. New sessions will be created in first middleware which // does `Acquire` and will be reused in other places. - sess, err := sessMan.Acquire(r, w, nil) - - // Use 'Set` and `Commit` to set a field for session. - // 'Set` ideally doesn't persist the value to store unless method `Commit` is called. - // But note that its up to the store you are using to decide to - // persist data only on `commit` or persist on `Set` itself. - // Stores like redis, db etc should persist on `Commit` while in-memory does on `Set`. - // No matter what store you use its better to explicitly - // call `Commit` method when you set all the values. + // + // If `Options.EnableAutoCreate` is set to True then if session doesn't exist it will + // be immediately created and returned. Bydefault its set to False so if session doesn't + // exist then `ErrInvalidSession` error is returned. + sess, err := sessMan.Acquire(nil, r, w) + + // If session doesn't exist then create new session. + // In a traditional login flow you can create a new session once user completes the login flow. + if err == simplesessions.ErrInvalidSession { + sess, err = sessMan.NewSession(r, w) + } + + // Use 'Set` or `SetMulti` to set a field for session. err = sess.Set("somekey", "somevalue") - err = sess.Set("someotherkey", 10) - err = sess.Commit() + err = sess.SetMulti(map[string]interface{}{ + "k1": "v1", + "k2": "v2", + }) // Use `Get` method to get a field from current session. The result will be an interface // so you can use helper methods like // `String', `Int`, `Int64`, `UInt64`, `Float64`, `Bytes`, `Bool`. val, err := sess.String(sess.Get("somekey")) + fmt.Println("val=", val) // Use `GetAll` to get map of all fields from session. // The result is map of string and interface you can use helper methods to type cast it. - val, err := sess.GetAll() + all, err := sess.GetAll() + fmt.Println("all=", all) // Use `GetMulti` to get values for given fields from session. // The result is map of string and interface you can use helper methods to type cast it. // If key is not there then store should ideally send `nil` value for given key. - val, err := sess.GetMulti("somekey", "someotherkey") + vals, err := sess.GetMulti("somekey", "someotherkey") + fmt.Println("vals=", vals) // Use `Delete` to delete a field from session. - err := sess.Delete("somekey") + err = sess.Delete("somekey") + + // Use `Clear` to empty the session but to keep the session alive. + err = sess.Clear() - // Use `Clear` to clear session from store. - err := sess.Clear() + // Use `Destroy` to clear session from store and cookie. + err = sess.Destroy() fmt.Fprintf(w, "success") } @@ -126,16 +142,45 @@ func handler(w http.ResponseWriter, r *http.Request) { func main() { // Create a session manager with custom options like cookie name, // cookie domain, is secure cookie etc. Check `Options` struct for more options. - sessMan := simplesessions.New(simplesessions.Options{}) + sessMan := simplesessions.New(simplesessions.Options{ + // If set to true then `Acquire()` method will create new session instead of throwing + // `ErrInvalidSession` when the session doesn't exist. By default its set to false. + EnableAutoCreate: false, + Cookie: simplesessions.CookieOptions{ + // Name sets http cookie name. This is also sent as cookie name in `GetCookie` callback. + Name: "session", + // Domain sets hostname for the cookie. Domain specifies allowed hosts to receive the cookie. + Domain: "example.com", + // Path sets path for the cookie. Path indicates a URL path that must exist in the requested URL in order to send the cookie header. + Path: "/", + // IsSecure marks the cookie as secure cookie (only sent in HTTPS). + IsSecure: true, + // IsHTTPOnly marks the cookie as http only cookie. JS won't be able to access the cookie so prevents XSS attacks. + IsHTTPOnly: true, + // SameSite sets allows you to declare if your cookie should be restricted to a first-party or same-site context. + SameSite: http.SameSiteDefaultMode, + // Expires sets absolute expiration date and time for the cookie. + // If both Expires and MaxAge are sent then MaxAge takes precedence over Expires. + // Cookies without a Max-age or Expires attribute – are deleted when the current session ends + // and some browsers use session restoring when restarting. This can cause session cookies to last indefinitely. + Expires: time.Now().Add(time.Hour * 24), + // Sets the cookie's expiration in seconds from the current time, internally its rounder off to nearest seconds. + // If both Expires and MaxAge are sent then MaxAge takes precedence over Expires. + // Cookies without a Max-age or Expires attribute – are deleted when the current session ends + // and some browsers use session restoring when restarting. This can cause session cookies to last indefinitely. + MaxAge: time.Hour * 24, + }, + }) + // Create a new store instance and attach to session manager sessMan.UseStore(memory.New()) - // Register callbacks for read and write cookie + // Register callbacks for read and write cookie. // Get cookie callback should get cookie based on cookie name and // sent back in net/http cookie format. - sessMan.RegisterGetCookie(getCookie) // Set cookie callback should set cookie it received for received cookie name. - sessMan.RegisterSetCookie(setCookie) + sessMan.SetCookieHooks(getCookie, setCookie) - http.HandleFunc("/set", handler) + // Initialize the handler. + http.HandleFunc("/", handler) } ``` diff --git a/TODO b/TODO deleted file mode 100644 index 8a93ec3..0000000 --- a/TODO +++ /dev/null @@ -1,14 +0,0 @@ -[x] Redis store -[x] Helper methods for type assertion -[x] Provision in store to call `SetCookie` -[x] Tests for session manager -[x] Tests for session -[x] Tests for Redis store -[x] In-memory store -[x] Tests for in-memory store -[x] Net http example -[x] Delete method for deleting individual field in session -[x] Fasthttp examples -[x] Secure cookie store with optional encoding and decoding -[x] Tests for Secure cookie store -[ ] Benchmark comparing with gosessions, gorilla sessions and scs diff --git a/conv/conv.go b/conv/conv.go deleted file mode 100644 index acb0f94..0000000 --- a/conv/conv.go +++ /dev/null @@ -1,187 +0,0 @@ -// Package conv to help type assertions and conversions. -package conv - -import ( - "strconv" -) - -var ( - // Error codes for store errors. This should match the codes - // defined in the /simplesessions package exactly. - ErrInvalidSession = &Err{code: 1, msg: "invalid session"} - ErrFieldNotFound = &Err{code: 2, msg: "field not found"} - ErrAssertType = &Err{code: 3, msg: "assertion failed"} - ErrNil = &Err{code: 4, msg: "nil returned"} -) - -type Err struct { - code int - msg string -} - -func (e *Err) Error() string { - return e.msg -} - -func (e *Err) Code() int { - return e.code -} - -// Int converts interface to integer. -func Int(r interface{}, err error) (int, error) { - if err != nil { - return 0, err - } - - switch r := r.(type) { - case int: - return r, nil - case int64: - x := int(r) - if int64(x) != r { - return 0, strconv.ErrRange - } - return x, nil - case []byte: - n, err := strconv.ParseInt(string(r), 10, 0) - return int(n), err - case string: - n, err := strconv.ParseInt(r, 10, 0) - return int(n), err - case nil: - return 0, ErrNil - } - - return 0, ErrAssertType -} - -// Int64 converts interface to Int64. -func Int64(r interface{}, err error) (int64, error) { - if err != nil { - return 0, err - } - - switch r := r.(type) { - case int: - return int64(r), nil - case int64: - return r, nil - case []byte: - n, err := strconv.ParseInt(string(r), 10, 64) - return n, err - case string: - n, err := strconv.ParseInt(r, 10, 64) - return n, err - case nil: - return 0, ErrNil - } - - return 0, ErrAssertType -} - -// UInt64 converts interface to UInt64. -func UInt64(r interface{}, err error) (uint64, error) { - if err != nil { - return 0, err - } - - switch r := r.(type) { - case uint64: - return r, err - case int: - if r < 0 { - return 0, ErrAssertType - } - return uint64(r), nil - case int64: - if r < 0 { - return 0, ErrAssertType - } - return uint64(r), nil - case []byte: - n, err := strconv.ParseUint(string(r), 10, 64) - return n, err - case string: - n, err := strconv.ParseUint(r, 10, 64) - return n, err - case nil: - return 0, ErrNil - } - - return 0, ErrAssertType -} - -// Float64 converts interface to Float64. -func Float64(r interface{}, err error) (float64, error) { - if err != nil { - return 0, err - } - switch r := r.(type) { - case float64: - return r, err - case []byte: - n, err := strconv.ParseFloat(string(r), 64) - return n, err - case string: - n, err := strconv.ParseFloat(r, 64) - return n, err - case nil: - return 0, ErrNil - } - return 0, ErrAssertType -} - -// String converts interface to String. -func String(r interface{}, err error) (string, error) { - if err != nil { - return "", err - } - switch r := r.(type) { - case []byte: - return string(r), nil - case string: - return r, nil - case nil: - return "", ErrNil - } - return "", ErrAssertType -} - -// Bytes converts interface to Bytes. -func Bytes(r interface{}, err error) ([]byte, error) { - if err != nil { - return nil, err - } - switch r := r.(type) { - case []byte: - return r, nil - case string: - return []byte(r), nil - case nil: - return nil, ErrNil - } - return nil, ErrAssertType -} - -// Bool converts interface to Bool. -func Bool(r interface{}, err error) (bool, error) { - if err != nil { - return false, err - } - switch r := r.(type) { - case bool: - return r, err - // Very common in redis to reply int64 with 0 for bool flag. - case int: - return r != 0, nil - case int64: - return r != 0, nil - case []byte: - return strconv.ParseBool(string(r)) - case string: - return strconv.ParseBool(r) - case nil: - return false, ErrNil - } - return false, ErrAssertType -} diff --git a/conv/conv_test.go b/conv/conv_test.go deleted file mode 100644 index 2ce0a9b..0000000 --- a/conv/conv_test.go +++ /dev/null @@ -1,284 +0,0 @@ -package conv - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/assert" -) - -var ( - errTest = errors.New("test error") -) - -func TestInt(t *testing.T) { - assert := assert.New(t) - - v, err := Int(1, nil) - assert.NoError(err) - assert.Equal(1, v) - - v, err = Int("1", nil) - assert.NoError(err) - assert.Equal(1, v) - - v, err = Int([]byte("1"), nil) - assert.NoError(err) - assert.Equal(1, v) - - var tVal int64 = 1 - v, err = Int(tVal, nil) - assert.NoError(err) - assert.Equal(1, v) - - var tVal1 interface{} = 1 - v, err = Int(tVal1, nil) - assert.NoError(err) - assert.Equal(1, v) - - // Test if ErrNil is returned if value is nil. - v, err = Int(nil, nil) - assert.Error(err, ErrNil) - assert.Equal(0, v) - - // Test if custom error sent is returned. - v, err = Int(nil, errTest) - assert.Error(err, errTest) - assert.Equal(0, v) - - // Test invalid assert error. - v, err = Int(10.1112, nil) - assert.Error(err, ErrAssertType) - assert.Equal(0, v) -} - -func TestInt64(t *testing.T) { - assert := assert.New(t) - - v, err := Int64(int64(1), nil) - assert.NoError(err) - assert.Equal(int64(1), v) - - v, err = Int64("1", nil) - assert.NoError(err) - assert.Equal(int64(1), v) - - v, err = Int64([]byte("1"), nil) - assert.NoError(err) - assert.Equal(int64(1), v) - - var tVal interface{} = 1 - v, err = Int64(tVal, nil) - assert.NoError(err) - assert.Equal(int64(1), v) - - // Test if ErrNil is returned if value is nil. - v, err = Int64(nil, nil) - assert.Error(err, ErrNil) - assert.Equal(int64(0), v) - - // Test if custom error sent is returned. - v, err = Int64(nil, errTest) - assert.Error(err, errTest) - assert.Equal(int64(0), v) - - // Test invalid assert error. - v, err = Int64(10.1112, nil) - assert.Error(err, ErrAssertType) - assert.Equal(int64(0), v) -} - -func TestUInt64(t *testing.T) { - assert := assert.New(t) - - v, err := UInt64(uint64(1), nil) - assert.NoError(err) - assert.Equal(uint64(1), v) - - v, err = UInt64("1", nil) - assert.NoError(err) - assert.Equal(uint64(1), v) - - v, err = UInt64([]byte("1"), nil) - assert.NoError(err) - assert.Equal(uint64(1), v) - - var tVal interface{} = 1 - v, err = UInt64(tVal, nil) - assert.NoError(err) - assert.Equal(uint64(1), v) - - // Test if ErrNil is returned if value is nil. - v, err = UInt64(nil, nil) - assert.Error(err, ErrNil) - assert.Equal(uint64(0), v) - - // Test if custom error sent is returned. - v, err = UInt64(nil, errTest) - assert.Error(err, errTest) - assert.Equal(uint64(0), v) - - // Test invalid assert error. - v, err = UInt64(10.1112, nil) - assert.Error(err, ErrAssertType) - assert.Equal(uint64(0), v) -} - -func TestFloat64(t *testing.T) { - assert := assert.New(t) - - v, err := Float64(float64(1.11), nil) - assert.NoError(err) - assert.Equal(float64(1.11), v) - - v, err = Float64("1.11", nil) - assert.NoError(err) - assert.Equal(float64(1.11), v) - - v, err = Float64([]byte("1.11"), nil) - assert.NoError(err) - assert.Equal(float64(1.11), v) - - var tVal float64 = 1.11 - v, err = Float64(tVal, nil) - assert.NoError(err) - assert.Equal(float64(1.11), v) - - // Test if ErrNil is returned if value is nil. - v, err = Float64(nil, nil) - assert.Error(err, ErrNil) - assert.Equal(float64(0), v) - - // Test if custom error sent is returned. - v, err = Float64(nil, errTest) - assert.Error(err, errTest) - assert.Equal(float64(0), v) - - // Test invalid assert error. - v, err = Float64("abc", nil) - assert.Error(err, ErrAssertType) - assert.Equal(float64(0), v) -} - -func TestString(t *testing.T) { - assert := assert.New(t) - - v, err := String("abc", nil) - assert.NoError(err) - assert.Equal("abc", v) - - v, err = String([]byte("abc"), nil) - assert.NoError(err) - assert.Equal("abc", v) - - var tVal interface{} = "abc" - v, err = String(tVal, nil) - assert.NoError(err) - assert.Equal("abc", v) - - // Test if ErrNil is returned if value is nil. - v, err = String(nil, nil) - assert.Error(err, ErrNil) - assert.Equal("", v) - - // Test if custom error sent is returned. - v, err = String(nil, errTest) - assert.Error(err, errTest) - assert.Equal("", v) - - // Test invalid assert error. - v, err = String(10.1112, nil) - assert.Error(err, ErrAssertType) - assert.Equal("", v) -} - -func TestBytes(t *testing.T) { - assert := assert.New(t) - - v, err := Bytes("abc", nil) - assert.NoError(err) - assert.Equal([]byte("abc"), v) - - v, err = Bytes([]byte("abc"), nil) - assert.NoError(err) - assert.Equal([]byte("abc"), v) - - var tVal interface{} = "abc" - v, err = Bytes(tVal, nil) - assert.NoError(err) - assert.Equal([]byte("abc"), v) - - // Test if ErrNil is returned if value is nil. - v, err = Bytes(nil, nil) - assert.Error(err, ErrNil) - assert.Equal([]byte(nil), v) - - // Test if custom error sent is returned. - v, err = Bytes(nil, errTest) - assert.Error(err, errTest) - assert.Equal([]byte(nil), v) - - // Test invalid assert error. - v, err = Bytes(10.1112, nil) - assert.Error(err, ErrAssertType) - assert.Equal([]byte(nil), v) -} - -func TestBool(t *testing.T) { - assert := assert.New(t) - - v, err := Bool(true, nil) - assert.NoError(err) - assert.Equal(true, v) - - v, err = Bool(false, nil) - assert.NoError(err) - assert.Equal(false, v) - - v, err = Bool(0, nil) - assert.NoError(err) - assert.Equal(false, v) - - v, err = Bool(1, nil) - assert.NoError(err) - assert.Equal(true, v) - - v, err = Bool(int64(0), nil) - assert.NoError(err) - assert.Equal(false, v) - - v, err = Bool(int64(1), nil) - assert.NoError(err) - assert.Equal(true, v) - - v, err = Bool([]byte("true"), nil) - assert.NoError(err) - assert.Equal(true, v) - - v, err = Bool([]byte("false"), nil) - assert.NoError(err) - assert.Equal(false, v) - - v, err = Bool("true", nil) - assert.NoError(err) - assert.Equal(true, v) - - v, err = Bool("false", nil) - assert.NoError(err) - assert.Equal(false, v) - - // Test if ErrNil is returned if value is nil. - v, err = Bool(nil, nil) - assert.Error(err, ErrNil) - assert.Equal(false, v) - - // Test if custom error sent is returned. - v, err = Bool(nil, errTest) - assert.Error(err, errTest) - assert.Equal(false, v) - - // Test invalid assert error. - v, err = Bool(10.1112, nil) - assert.Error(err, ErrAssertType) - assert.Equal(false, v) -} diff --git a/conv/go.mod b/conv/go.mod deleted file mode 100644 index cab6c20..0000000 --- a/conv/go.mod +++ /dev/null @@ -1,5 +0,0 @@ -module github.com/vividvilla/simplesessions/conv - -go 1.14 - -require github.com/stretchr/testify v1.9.0 diff --git a/examples/fastglue-goredis/go.mod b/examples/fastglue-goredis/go.mod deleted file mode 100644 index 7c4f562..0000000 --- a/examples/fastglue-goredis/go.mod +++ /dev/null @@ -1,13 +0,0 @@ -module github.com/vividvilla/simplesessions/examples/fastglue-goredis - -go 1.16 - -require ( - github.com/fasthttp/router v1.5.0 // indirect - github.com/klauspost/compress v1.17.8 // indirect - github.com/redis/go-redis/v9 v9.5.1 - github.com/valyala/fasthttp v1.52.0 - github.com/vividvilla/simplesessions/stores/goredis/v9 v9.0.0 - github.com/vividvilla/simplesessions/v2 v2.0.0 - github.com/zerodha/fastglue v1.8.0 -) diff --git a/examples/fastglue-goredis/main.go b/examples/fastglue-redis/main.go similarity index 80% rename from examples/fastglue-goredis/main.go rename to examples/fastglue-redis/main.go index 6475d68..4066d6e 100644 --- a/examples/fastglue-goredis/main.go +++ b/examples/fastglue-redis/main.go @@ -8,8 +8,8 @@ import ( "github.com/redis/go-redis/v9" "github.com/valyala/fasthttp" - redisstore "github.com/vividvilla/simplesessions/stores/goredis/v9" - "github.com/vividvilla/simplesessions/v2" + redisstore "github.com/vividvilla/simplesessions/stores/redis/v3" + "github.com/vividvilla/simplesessions/v3" "github.com/zerodha/fastglue" ) @@ -18,9 +18,9 @@ const ( ) var ( - sessionManager *simplesessions.Manager - testKey = "question" - testValue = 42 + sessMgr *simplesessions.Manager + testKey = "question" + testValue = 42 ) func initRedisGo(address, password string) *redis.Client { @@ -43,8 +43,11 @@ func initServer(name string, timeout int) *fasthttp.Server { } func setHandler(r *fastglue.Request) error { - - sess, err := sessionManager.Acquire(r.RequestCtx, r.RequestCtx, nil) + sess, err := sessMgr.Acquire(nil, r.RequestCtx, r.RequestCtx) + // Create new session if it doesn't exist. + if err == simplesessions.ErrInvalidSession { + sess, err = sessMgr.NewSession(r.RequestCtx, r.RequestCtx) + } if err != nil { return r.SendErrorEnvelope(fasthttp.StatusInternalServerError, err.Error(), nil, GeneralError) } @@ -54,15 +57,11 @@ func setHandler(r *fastglue.Request) error { return r.SendErrorEnvelope(fasthttp.StatusInternalServerError, err.Error(), nil, GeneralError) } - if err = sess.Commit(); err != nil { - return r.SendErrorEnvelope(fasthttp.StatusInternalServerError, err.Error(), nil, GeneralError) - } - return r.SendEnvelope("success") } func getHandler(r *fastglue.Request) error { - sess, err := sessionManager.Acquire(r.RequestCtx, r.RequestCtx, nil) + sess, err := sessMgr.Acquire(nil, r.RequestCtx, r.RequestCtx) if err != nil { return r.SendErrorEnvelope(fasthttp.StatusInternalServerError, err.Error(), nil, GeneralError) } @@ -128,10 +127,9 @@ func main() { ctx := context.Background() store := redisstore.New(ctx, rc) - sessionManager = simplesessions.New(simplesessions.Options{}) - sessionManager.UseStore(store) - sessionManager.RegisterGetCookie(getCookie) - sessionManager.RegisterSetCookie(setCookie) + sessMgr = simplesessions.New(simplesessions.Options{}) + sessMgr.UseStore(store) + sessMgr.SetCookieHooks(getCookie, setCookie) g := fastglue.New() g.GET("/get", getHandler) @@ -139,7 +137,7 @@ func main() { // 5s read/write timeout server := initServer("go-redis", 5) - if err := g.ListenAndServe(":3000", "", server); err != nil { + if err := g.ListenAndServe(":1111", "", server); err != nil { log.Fatal(err) } } diff --git a/examples/fasthttp-inmemory/go.mod b/examples/fasthttp-inmemory/go.mod deleted file mode 100644 index e66abb4..0000000 --- a/examples/fasthttp-inmemory/go.mod +++ /dev/null @@ -1,10 +0,0 @@ -module github.com/vividvilla/simplesessions/examples/fasthttp-inmemory - -require ( - github.com/klauspost/compress v1.4.0 // indirect - github.com/klauspost/cpuid v0.0.0-20180405133222-e7e905edc00e // indirect - github.com/valyala/bytebufferpool v0.0.0-20160817181652-e746df99fe4a // indirect - github.com/valyala/fasthttp v0.0.0-20180901052036-d7688109a57b - github.com/vividvilla/simplesessions v0.0.1 - github.com/vividvilla/simplesessions/stores/memory v0.0.0-20180905073812-64bb2453ba8a -) diff --git a/examples/fasthttp-inmemory/main.go b/examples/fasthttp-inmemory/main.go index c7293d9..ba50f17 100644 --- a/examples/fasthttp-inmemory/main.go +++ b/examples/fasthttp-inmemory/main.go @@ -5,31 +5,31 @@ import ( "net/http" "github.com/valyala/fasthttp" - "github.com/vividvilla/simplesessions" - "github.com/vividvilla/simplesessions/stores/memory" + "github.com/vividvilla/simplesessions/stores/memory/v3" + "github.com/vividvilla/simplesessions/v3" ) var ( - sessionManager *simplesessions.Manager + sessMgr *simplesessions.Manager testKey = "abc123" testValue = 123456 ) func setHandler(ctx *fasthttp.RequestCtx) { - sess, err := sessionManager.Acquire(ctx, ctx, nil) - if err != nil { - ctx.Error(err.Error(), 500) - return + sess, err := sessMgr.Acquire(nil, ctx, ctx) + // Create new session if it doesn't exist. + if err == simplesessions.ErrInvalidSession { + sess, err = sessMgr.NewSession(ctx, ctx) } - err = sess.Set(testKey, testValue) if err != nil { ctx.Error(err.Error(), 500) return } - if err = sess.Commit(); err != nil { + err = sess.Set(testKey, testValue) + if err != nil { ctx.Error(err.Error(), 500) return } @@ -38,7 +38,7 @@ func setHandler(ctx *fasthttp.RequestCtx) { } func getHandler(ctx *fasthttp.RequestCtx) { - sess, err := sessionManager.Acquire(ctx, ctx, nil) + sess, err := sessMgr.Acquire(nil, ctx, ctx) if err != nil { ctx.Error(err.Error(), 500) return @@ -99,10 +99,9 @@ func setCookie(cookie *http.Cookie, w interface{}) error { } func main() { - sessionManager = simplesessions.New(simplesessions.Options{}) - sessionManager.UseStore(memory.New()) - sessionManager.RegisterGetCookie(getCookie) - sessionManager.RegisterSetCookie(setCookie) + sessMgr = simplesessions.New(simplesessions.Options{}) + sessMgr.UseStore(memory.New()) + sessMgr.SetCookieHooks(getCookie, setCookie) m := func(ctx *fasthttp.RequestCtx) { switch string(ctx.Path()) { diff --git a/examples/fasthttp-redis/go.mod b/examples/fasthttp-redis/go.mod deleted file mode 100644 index 84fe957..0000000 --- a/examples/fasthttp-redis/go.mod +++ /dev/null @@ -1,16 +0,0 @@ -module github.com/vividvilla/simplesessions/examples/fasthttp-redis - -go 1.14 - -require ( - github.com/gomodule/redigo v2.0.0+incompatible - github.com/valyala/fasthttp v1.52.0 - github.com/vividvilla/simplesessions/stores/redis/v2 v2.0.0 - github.com/vividvilla/simplesessions/v2 v2.0.0 -) - -require ( - github.com/andybalholm/brotli v1.1.0 // indirect - github.com/klauspost/compress v1.17.6 // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect -) diff --git a/examples/fasthttp-redis/main.go b/examples/fasthttp-redis/main.go index 0ac0035..617af78 100644 --- a/examples/fasthttp-redis/main.go +++ b/examples/fasthttp-redis/main.go @@ -1,37 +1,39 @@ package main import ( + "context" "fmt" + "log" "net/http" "time" - "github.com/gomodule/redigo/redis" + "github.com/redis/go-redis/v9" "github.com/valyala/fasthttp" - redisstore "github.com/vividvilla/simplesessions/stores/redis/v2" - "github.com/vividvilla/simplesessions/v2" + redisstore "github.com/vividvilla/simplesessions/stores/redis/v3" + "github.com/vividvilla/simplesessions/v3" ) var ( - sessionManager *simplesessions.Manager + sessMgr *simplesessions.Manager testKey = "abc123" testValue = 123456 ) func setHandler(ctx *fasthttp.RequestCtx) { - sess, err := sessionManager.Acquire(ctx, ctx, nil) - if err != nil { - ctx.Error(err.Error(), 500) - return + sess, err := sessMgr.Acquire(nil, ctx, ctx) + // Create new session if it doesn't exist. + if err == simplesessions.ErrInvalidSession { + sess, err = sessMgr.NewSession(ctx, ctx) } - err = sess.Set(testKey, testValue) if err != nil { ctx.Error(err.Error(), 500) return } - if err = sess.Commit(); err != nil { + err = sess.Set(testKey, testValue) + if err != nil { ctx.Error(err.Error(), 500) return } @@ -40,7 +42,7 @@ func setHandler(ctx *fasthttp.RequestCtx) { } func getHandler(ctx *fasthttp.RequestCtx) { - sess, err := sessionManager.Acquire(ctx, ctx, nil) + sess, err := sessMgr.Acquire(ctx, ctx, nil) if err != nil { ctx.Error(err.Error(), 500) return @@ -100,34 +102,33 @@ func setCookie(cookie *http.Cookie, w interface{}) error { return nil } -func getRedisPool(address string, password string, maxActive int, maxIdle int, timeout time.Duration) *redis.Pool { - return &redis.Pool{ - Wait: true, - MaxActive: maxActive, - MaxIdle: maxIdle, - Dial: func() (redis.Conn, error) { - c, err := redis.Dial( - "tcp", - address, - redis.DialPassword(password), - redis.DialConnectTimeout(timeout), - redis.DialReadTimeout(timeout), - redis.DialWriteTimeout(timeout), - ) - - return c, err - }, +func getRedisPool() redis.UniversalClient { + o := &redis.Options{ + Addr: "localhost:6379", + Username: "", + Password: "", + DialTimeout: time.Second * 3, + DB: 0, + } + + var ( + ctx = context.TODO() + cl = redis.NewClient(o) + ) + if err := cl.Ping(ctx).Err(); err != nil { + log.Fatalf("error initializing redis: %v", err) } + + return cl } func main() { - rPool := getRedisPool("localhost:6379", "", 10, 10, 1000*time.Millisecond) + rPool := getRedisPool() - sessionManager = simplesessions.New(simplesessions.Options{}) - store := redisstore.New(rPool) - sessionManager.UseStore(store) - sessionManager.RegisterGetCookie(getCookie) - sessionManager.RegisterSetCookie(setCookie) + sessMgr = simplesessions.New(simplesessions.Options{}) + store := redisstore.New(context.TODO(), rPool) + sessMgr.UseStore(store) + sessMgr.SetCookieHooks(getCookie, setCookie) m := func(ctx *fasthttp.RequestCtx) { switch string(ctx.Path()) { diff --git a/examples/go.mod b/examples/go.mod new file mode 100644 index 0000000..41c3e54 --- /dev/null +++ b/examples/go.mod @@ -0,0 +1,35 @@ +module github.com/vividvilla/simplesessions/examples + +go 1.18 + +require ( + github.com/redis/go-redis/v9 v9.5.1 + github.com/valyala/fasthttp v1.44.0 + github.com/vividvilla/simplesessions/stores/memory/v3 v3.0.0 + github.com/vividvilla/simplesessions/stores/redis/v3 v3.0.0 + github.com/vividvilla/simplesessions/stores/securecookie/v3 v3.0.0 + github.com/vividvilla/simplesessions/v3 v3.0.0 + github.com/zerodha/fastglue v1.8.0 +) + +require ( + github.com/andybalholm/brotli v1.0.4 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/fasthttp/router v1.4.5 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect + github.com/klauspost/compress v1.15.9 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/savsgio/gotils v0.0.0-20211223103454-d0aaa54c5899 // indirect + github.com/stretchr/testify v1.9.0 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace ( + github.com/vividvilla/simplesessions/stores/memory/v3 => ../stores/memory + github.com/vividvilla/simplesessions/stores/redis/v3 => ../stores/redis + github.com/vividvilla/simplesessions/stores/securecookie/v3 => ../stores/securecookie + github.com/vividvilla/simplesessions/v3 => ../ +) diff --git a/examples/nethttp-inmemory/go.mod b/examples/nethttp-inmemory/go.mod deleted file mode 100644 index 2ed9491..0000000 --- a/examples/nethttp-inmemory/go.mod +++ /dev/null @@ -1,8 +0,0 @@ -module github.com/vividvilla/simplesessions/examples/nethttp-inmemory - -go 1.14 - -require ( - github.com/vividvilla/simplesessions/stores/memory/v2 v2.0.0 - github.com/vividvilla/simplesessions/v2 v2.0.0 -) diff --git a/examples/nethttp-inmemory/main.go b/examples/nethttp-inmemory/main.go index 6382b34..d79bed9 100644 --- a/examples/nethttp-inmemory/main.go +++ b/examples/nethttp-inmemory/main.go @@ -5,31 +5,31 @@ import ( "log" "net/http" - "github.com/vividvilla/simplesessions/stores/memory/v2" - "github.com/vividvilla/simplesessions/v2" + "github.com/vividvilla/simplesessions/stores/memory/v3" + "github.com/vividvilla/simplesessions/v3" ) var ( - sessionManager *simplesessions.Manager - + sessMgr *simplesessions.Manager testKey = "abc123" testValue = 123456 ) func setHandler(w http.ResponseWriter, r *http.Request) { - sess, err := sessionManager.Acquire(r, w, nil) - if err != nil { - http.Error(w, err.Error(), 500) - return + sess, err := sessMgr.Acquire(nil, r, w) + + // Create new session if it doesn't exist. + if err == simplesessions.ErrInvalidSession { + sess, err = sessMgr.NewSession(r, w) } - err = sess.Set(testKey, testValue) if err != nil { http.Error(w, err.Error(), 500) return } - if err = sess.Commit(); err != nil { + err = sess.Set(testKey, testValue) + if err != nil { http.Error(w, err.Error(), 500) return } @@ -38,7 +38,7 @@ func setHandler(w http.ResponseWriter, r *http.Request) { } func getHandler(w http.ResponseWriter, r *http.Request) { - sess, err := sessionManager.Acquire(r, w, nil) + sess, err := sessMgr.Acquire(nil, r, w) if err != nil { http.Error(w, err.Error(), 500) return @@ -70,10 +70,9 @@ func setCookie(cookie *http.Cookie, w interface{}) error { } func main() { - sessionManager = simplesessions.New(simplesessions.Options{}) - sessionManager.UseStore(memory.New()) - sessionManager.RegisterGetCookie(getCookie) - sessionManager.RegisterSetCookie(setCookie) + sessMgr = simplesessions.New(simplesessions.Options{}) + sessMgr.UseStore(memory.New()) + sessMgr.SetCookieHooks(getCookie, setCookie) http.HandleFunc("/set", setHandler) http.HandleFunc("/get", getHandler) diff --git a/examples/nethttp-redis/go.mod b/examples/nethttp-redis/go.mod deleted file mode 100644 index ee59b93..0000000 --- a/examples/nethttp-redis/go.mod +++ /dev/null @@ -1,9 +0,0 @@ -module github.com/vividvilla/simplesessions/examples/nethttp-redis - -go 1.14 - -require ( - github.com/gomodule/redigo v2.0.0+incompatible - github.com/vividvilla/simplesessions/stores/redis/v2 v2.0.0 - github.com/vividvilla/simplesessions/v2 v2.0.0 -) diff --git a/examples/nethttp-redis/main.go b/examples/nethttp-redis/main.go index db7fc68..e2d525b 100644 --- a/examples/nethttp-redis/main.go +++ b/examples/nethttp-redis/main.go @@ -1,37 +1,39 @@ package main import ( + "context" "fmt" "log" "net/http" "time" - "github.com/gomodule/redigo/redis" - redisstore "github.com/vividvilla/simplesessions/stores/redis/v2" - "github.com/vividvilla/simplesessions/v2" + "github.com/redis/go-redis/v9" + redisstore "github.com/vividvilla/simplesessions/stores/redis/v3" + "github.com/vividvilla/simplesessions/v3" ) var ( - sessionManager *simplesessions.Manager + sessMgr *simplesessions.Manager testKey = "abc123" testValue = 123456 ) func setHandler(w http.ResponseWriter, r *http.Request) { - sess, err := sessionManager.Acquire(r, w, nil) - if err != nil { - http.Error(w, err.Error(), 500) - return + sess, err := sessMgr.Acquire(nil, r, w) + + // Create new session if it doesn't exist. + if err == simplesessions.ErrInvalidSession { + sess, err = sessMgr.NewSession(r, w) } - err = sess.Set(testKey, testValue) if err != nil { http.Error(w, err.Error(), 500) return } - if err = sess.Commit(); err != nil { + err = sess.Set(testKey, testValue) + if err != nil { http.Error(w, err.Error(), 500) return } @@ -40,7 +42,7 @@ func setHandler(w http.ResponseWriter, r *http.Request) { } func getHandler(w http.ResponseWriter, r *http.Request) { - sess, err := sessionManager.Acquire(r, w, nil) + sess, err := sessMgr.Acquire(nil, r, w) if err != nil { http.Error(w, err.Error(), 500) return @@ -71,34 +73,31 @@ func setCookie(cookie *http.Cookie, w interface{}) error { return nil } -func getRedisPool(address string, password string, maxActive int, maxIdle int, timeout time.Duration) *redis.Pool { - return &redis.Pool{ - Wait: true, - MaxActive: maxActive, - MaxIdle: maxIdle, - Dial: func() (redis.Conn, error) { - c, err := redis.Dial( - "tcp", - address, - redis.DialPassword(password), - redis.DialConnectTimeout(timeout), - redis.DialReadTimeout(timeout), - redis.DialWriteTimeout(timeout), - ) - - return c, err - }, +func getRedisPool() redis.UniversalClient { + o := &redis.Options{ + Addr: "localhost:6379", + Username: "", + Password: "", + DialTimeout: time.Second * 3, + DB: 0, + } + + var ( + ctx = context.TODO() + cl = redis.NewClient(o) + ) + if err := cl.Ping(ctx).Err(); err != nil { + log.Fatalf("error initializing redis: %v", err) } + + return cl } func main() { - rPool := getRedisPool("localhost:6379", "", 10, 10, 1000*time.Millisecond) - - sessionManager = simplesessions.New(simplesessions.Options{}) - store := redisstore.New(rPool) - sessionManager.UseStore(store) - sessionManager.RegisterGetCookie(getCookie) - sessionManager.RegisterSetCookie(setCookie) + sessMgr = simplesessions.New(simplesessions.Options{}) + store := redisstore.New(context.Background(), getRedisPool()) + sessMgr.UseStore(store) + sessMgr.SetCookieHooks(getCookie, setCookie) http.HandleFunc("/set", setHandler) http.HandleFunc("/get", getHandler) diff --git a/examples/nethttp-secure-cookie/go.mod b/examples/nethttp-secure-cookie/go.mod deleted file mode 100644 index f8debc2..0000000 --- a/examples/nethttp-secure-cookie/go.mod +++ /dev/null @@ -1,10 +0,0 @@ -module github.com/vividvilla/simplesessions/examples/nethttp-secure-cookie - -go 1.14 - -require ( - github.com/vividvilla/simplesessions/stores/securecookie/v2 v2.0.0 - github.com/vividvilla/simplesessions/v2 v2.0.0 -) - -require github.com/gorilla/securecookie v1.1.2 // indirect diff --git a/examples/nethttp-secure-cookie/main.go b/examples/nethttp-secure-cookie/main.go index 29fd47d..3999476 100644 --- a/examples/nethttp-secure-cookie/main.go +++ b/examples/nethttp-secure-cookie/main.go @@ -5,12 +5,12 @@ import ( "log" "net/http" - "github.com/vividvilla/simplesessions/stores/securecookie/v2" - "github.com/vividvilla/simplesessions/v2" + "github.com/vividvilla/simplesessions/stores/securecookie/v3" + "github.com/vividvilla/simplesessions/v3" ) var ( - sessionManager *simplesessions.Manager + sessMgr *simplesessions.Manager store = securecookie.New( []byte("0dIHy6S2uBuKaNnTUszB218L898ikGYA"), @@ -22,7 +22,21 @@ var ( ) func setHandler(w http.ResponseWriter, r *http.Request) { - sess, err := sessionManager.Acquire(r, w, nil) + sess, err := sessMgr.Acquire(nil, r, w) + // Create new session if it doesn't exist. + if err == simplesessions.ErrInvalidSession { + sess, err = sessMgr.NewSession(r, w) + + // IMPORTANT: any Set/SetMulti/Delete/Clear/Destroy and NewSession() + // should flush the values using `store.Flush()` otherwise cookie won't be updated. + if err == nil { + ck, err := store.Flush(sess.ID()) + if err == nil { + err = sess.WriteCookie(ck) + } + } + } + if err != nil { http.Error(w, err.Error(), 500) return @@ -52,7 +66,7 @@ func setHandler(w http.ResponseWriter, r *http.Request) { } func getHandler(w http.ResponseWriter, r *http.Request) { - sess, err := sessionManager.Acquire(r, w, nil) + sess, err := sessMgr.Acquire(nil, r, w) if err != nil { http.Error(w, err.Error(), 500) return @@ -84,11 +98,9 @@ func setCookie(cookie *http.Cookie, w interface{}) error { } func main() { - sessionManager = simplesessions.New(simplesessions.Options{}) - sessionManager.UseStore(store) - - sessionManager.RegisterGetCookie(getCookie) - sessionManager.RegisterSetCookie(setCookie) + sessMgr = simplesessions.New(simplesessions.Options{}) + sessMgr.UseStore(store) + sessMgr.SetCookieHooks(getCookie, setCookie) http.HandleFunc("/set", setHandler) http.HandleFunc("/get", getHandler) diff --git a/go.mod b/go.mod index b8af66f..49d9411 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,11 @@ -module github.com/vividvilla/simplesessions/v2 +module github.com/vividvilla/simplesessions/v3 + +require github.com/stretchr/testify v1.9.0 require ( - github.com/stretchr/testify v1.9.0 - github.com/valyala/fasthttp v1.40.0 + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) -go 1.14 +go 1.18 diff --git a/go.work b/go.work index f692f89..2b82182 100644 --- a/go.work +++ b/go.work @@ -1,11 +1,10 @@ -go 1.14 +go 1.18 use ( . - ./conv - ./stores/goredis ./stores/memory + ./stores/postgres ./stores/redis ./stores/securecookie - ./stores/postgres + ./examples ) diff --git a/manager.go b/manager.go index e46cd51..dd18f24 100644 --- a/manager.go +++ b/manager.go @@ -2,20 +2,31 @@ package simplesessions import ( "context" + "crypto/rand" "fmt" "net/http" "time" + "unicode" ) +// Context name type. +type ctxNameType string + const ( // Default cookie name used to store session. defaultCookieName = "session" + // Default cookie path. + defaultCookiePath = "/" + + // default sessionID length. + defaultSessIDLength = 32 + // ContextName is the key used to store session in context passed to acquire method. - ContextName = "_simple_session" + ContextName ctxNameType = "_simple_session" ) -// Manager is a utility to scaffold session and store. +// Manager handles the storage and management of HTTP cookies. type Manager struct { // Store to be used. store Store @@ -23,39 +34,64 @@ type Manager struct { // Store basic cookie details. opts *Options - // Callback to get http cookie. - getCookieCb func(name string, r interface{}) (*http.Cookie, error) + // Hook to get http cookie. + getCookieHook func(name string, r interface{}) (*http.Cookie, error) - // Callback to set http cookie. - setCookieCb func(cookie *http.Cookie, w interface{}) error + // Hook to set http cookie. + setCookieHook func(cookie *http.Cookie, w interface{}) error + + // generate cookie ID. + generateID func() (string, error) + + // validate cookie ID. + validateID func(string) bool } -// Options are available options to configure Manager. +// Options to configure manager and cookie. type Options struct { - // DisableAutoSet skips creation of session cookie in frontend and new session in store if session is not already set. - DisableAutoSet bool + // If enabled, Acquire() will always create and return a new session if one doesn't already exist. + // If disabled then new session can only be created using NewSession() method. + EnableAutoCreate bool - // CookieName sets http cookie name. This is also sent as cookie name in `GetCookie` callback. - CookieName string + // Cookie ID length. Defaults to alphanumeric 32 characters. + // Might not be applicable to some stores like SecureCookie. + // Also not applicable if custom generateID and validateID is set. + SessionIDLength int - // CookieDomain sets hostname for the cookie. Domain specifies allowed hosts to receive the cookie. - CookieDomain string + // Cookie options. + Cookie CookieOptions +} - // CookiePath sets path for the cookie. Path indicates a URL path that must exist in the requested URL in order to send the cookie header. - CookiePath string +type CookieOptions struct { + // Name sets http cookie name. This is also sent as cookie name in `GetCookie` callback. + Name string - // IsSecureCookie marks the cookie as secure cookie (only sent in HTTPS). - IsSecureCookie bool + // Domain sets hostname for the cookie. Domain specifies allowed hosts to receive the cookie. + Domain string - // IsHTTPOnlyCookie marks the cookie as http only cookie. JS won't be able to access the cookie so prevents XSS attacks. - IsHTTPOnlyCookie bool + // Path sets path for the cookie. Path indicates a URL path that must exist in the requested URL in order to send the cookie header. + Path string - // CookieLifeTime sets expiry time for cookie. - // If expiry time is not specified then cookie is set as session cookie which is cleared on browser close. - CookieLifetime time.Duration + // IsSecure marks the cookie as secure cookie (only sent in HTTPS). + IsSecure bool + + // IsHTTPOnly marks the cookie as http only cookie. JS won't be able to access the cookie so prevents XSS attacks. + IsHTTPOnly bool // SameSite sets allows you to declare if your cookie should be restricted to a first-party or same-site context. SameSite http.SameSite + + // Expires sets absolute expiration date and time for the cookie. + // If both Expires and MaxAge are sent then MaxAge takes precedence over Expires. + // Cookies without a Max-age or Expires attribute – are deleted when the current session ends + // and some browsers use session restoring when restarting. This can cause session cookies to last indefinitely. + Expires time.Time + + // Sets the cookie's expiration in seconds from the current time, internally its rounder off to nearest seconds. + // If both Expires and MaxAge are sent then MaxAge takes precedence over Expires. + // Cookies without a Max-age or Expires attribute – are deleted when the current session ends + // and some browsers use session restoring when restarting. This can cause session cookies to last indefinitely. + MaxAge time.Duration } // New creates a new session manager for given options. @@ -65,15 +101,23 @@ func New(opts Options) *Manager { } // Set default cookie name if not set - if m.opts.CookieName == "" { - m.opts.CookieName = defaultCookieName + if m.opts.Cookie.Name == "" { + m.opts.Cookie.Name = defaultCookieName } // If path not given then set to root path - if m.opts.CookiePath == "" { - m.opts.CookiePath = "/" + if m.opts.Cookie.Path == "" { + m.opts.Cookie.Path = defaultCookiePath + } + + if m.opts.SessionIDLength == 0 { + m.opts.SessionIDLength = defaultSessIDLength } + // Assign default set and validate generate ID. + m.generateID = m.defaultGenerateID + m.validateID = m.defaultValidateID + return m } @@ -82,39 +126,87 @@ func (m *Manager) UseStore(str Store) { m.store = str } -// RegisterGetCookie sets a callback to get http cookie from any reader interface which -// is sent on session acquisition using `Acquire` method. -func (m *Manager) RegisterGetCookie(cb func(string, interface{}) (*http.Cookie, error)) { - m.getCookieCb = cb +// SetCookieHooks cane be used to get and set HTTP cookie for the session. +// +// getCookie hook takes session ID and reader interface and returns http.Cookie and error. +// In a HTTP request context reader interface will be the http request object and +// it should obtain http.Cookie from the request object for the given cookie ID. +// +// setCookie hook takes http.Cookie object and a writer interface and returns error. +// In a HTTP request context the write interface will be the http request object and +// it should write http request with the incoming cookie. +func (m *Manager) SetCookieHooks(getCookie func(string, interface{}) (*http.Cookie, error), setCookie func(*http.Cookie, interface{}) error) { + m.getCookieHook = getCookie + m.setCookieHook = setCookie +} + +// SetSessionIDHooks cane be used to generate and validate custom session ID. +// Bydefault alpha-numeric 32bit length session ID is used if its not set. +// - Generating custom session ID, which will be uses as the ID for storing sessions in the backend. +// - Validating custom session ID, which will be used to verify the ID before querying backend. +func (m *Manager) SetSessionIDHooks(generateID func() (string, error), validateID func(string) bool) { + m.generateID = generateID + m.validateID = validateID } -// RegisterSetCookie sets a callback to set cookie from http writer interface which -// is sent on session acquisition using `Acquire` method. -func (m *Manager) RegisterSetCookie(cb func(*http.Cookie, interface{}) error) { - m.setCookieCb = cb +// NewSession creates a new `Session` and updates the cookie with a new session ID, +// replacing any existing session ID if it exists. +func (m *Manager) NewSession(r, w interface{}) (*Session, error) { + // Check if any store is set + if m.store == nil { + return nil, fmt.Errorf("session store not set") + } + + if m.setCookieHook == nil { + return nil, fmt.Errorf("`SetCookie` hook not set") + } + + // Create new cookie in store and write to front. + // Store also calls `WriteCookie`` to write to http interface. + id, err := m.generateID() + if err != nil { + return nil, errAs(err) + } + + if err = m.store.Create(id); err != nil { + return nil, errAs(err) + } + + var sess = &Session{ + id: id, + manager: m, + reader: r, + writer: w, + cache: nil, + } + // Write cookie. + if err := sess.WriteCookie(id); err != nil { + return nil, err + } + + return sess, nil } -// Acquire gets a `Session` for current session cookie from store. -// If `Session` is not found on store then it creates a new session and sets on store. -// If 'DisableAutoSet` is set in options then session has to be explicitly created before -// using `Session` for getting or setting. -// `r` and `w` is request and response interfaces which are sent back in GetCookie and SetCookie callbacks respectively. -// In case of net/http `r` will be r` -// Optionally context can be passed around which is used to get already loaded session. This is useful when -// handler is wrapped with multiple middlewares and `Acquire` is already called in any of the middleware. -func (m *Manager) Acquire(r, w interface{}, c context.Context) (*Session, error) { +// Acquire retrieves a `Session` from the store using the current session cookie. +// +// If session not found and `opt.EnableAutoCreate` is true, a new session is created and returned. +// If session not found and `opt.EnableAutoCreate` is false which is the default, it returns `ErrInvalidSession`. +// +// `r` and `w` are request and response interfaces which is passed back in in GetCookie and SetCookie callbacks. +// Optionally, a context can be passed to get an already loaded session, useful in middleware chains. +func (m *Manager) Acquire(c context.Context, r, w interface{}) (*Session, error) { // Check if any store is set if m.store == nil { - return nil, fmt.Errorf("session store is not set") + return nil, fmt.Errorf("session store not set") } // Check if callbacks are set - if m.getCookieCb == nil { - return nil, fmt.Errorf("callback `GetCookie` not set") + if m.getCookieHook == nil { + return nil, fmt.Errorf("`GetCookie` hook not set") } - if m.setCookieCb == nil { - return nil, fmt.Errorf("callback `SetCookie` not set") + if m.setCookieHook == nil { + return nil, fmt.Errorf("`SetCookie` hook not set") } // If a session was already set in the context by a middleware somewhere, return that. @@ -124,5 +216,58 @@ func (m *Manager) Acquire(r, w interface{}, c context.Context) (*Session, error) } } - return NewSession(m, r, w) + // Get existing HTTP session cookie. + // If there's no error and there's a session ID (unvalidated at this point), + // return a session object. + ck, err := m.getCookieHook(m.opts.Cookie.Name, r) + if err == nil && ck != nil && ck.Value != "" { + return &Session{ + manager: m, + reader: r, + writer: w, + id: ck.Value, + cache: nil, + }, nil + } + + // If auto-creation is disabled, return an error. + if !m.opts.EnableAutoCreate { + return nil, ErrInvalidSession + } + + return m.NewSession(r, w) +} + +// defaultGenerateID generates a random alpha-num session ID. +// This will be the default method to generate cookie ID and +// can override using `SetCookieIDGenerate` method. +func (m *Manager) defaultGenerateID() (string, error) { + const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + bytes := make([]byte, m.opts.SessionIDLength) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + for k, v := range bytes { + bytes[k] = dict[v%byte(len(dict))] + } + + return string(bytes), nil +} + +// defaultValidateID validates the incoming to ID to check +// if its alpha-numeric with configured cookie ID length. +// Can override using `SetCookieIDGenerate` method. +func (m *Manager) defaultValidateID(id string) bool { + if len(id) != m.opts.SessionIDLength { + return false + } + + for _, r := range id { + if !unicode.IsDigit(r) && !unicode.IsLetter(r) { + return false + } + } + + return true } diff --git a/manager_test.go b/manager_test.go index 2ba4d7b..eac02fd 100644 --- a/manager_test.go +++ b/manager_test.go @@ -2,6 +2,7 @@ package simplesessions import ( "context" + "fmt" "net/http" "testing" "time" @@ -9,167 +10,217 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNewManagerWithDefaultOptions(t *testing.T) { +const mockSessionID = "sometestcookievalue" + +func newMockStore() *MockStore { + return &MockStore{ + id: mockSessionID, + data: map[string]interface{}{}, + err: nil, + } +} + +func newMockManager(store *MockStore) *Manager { m := New(Options{}) + m.UseStore(store) + m.SetCookieHooks(mockGetCookieCb, mockSetCookieCb) + return m +} - assert := assert.New(t) +func mockGetCookieCb(name string, r interface{}) (*http.Cookie, error) { + return &http.Cookie{ + Name: name, + Value: mockSessionID, + }, nil +} + +func mockSetCookieCb(*http.Cookie, interface{}) error { + return nil +} + +func TestNewManagerWithDefaultOptions(t *testing.T) { + m := New(Options{}) // Default cookie path is set to root - assert.Equal(m.opts.CookiePath, "/") + assert.Equal(t, "/", m.opts.Cookie.Path) // Default cookie name is set - assert.Equal(m.opts.CookieName, defaultCookieName) + assert.Equal(t, defaultCookieName, m.opts.Cookie.Name) } func TestManagerNewManagerWithOptions(t *testing.T) { opts := Options{ - DisableAutoSet: true, - CookieName: "testcookiename", - CookieDomain: "somedomain", - CookiePath: "/abc/123", - IsSecureCookie: true, - IsHTTPOnlyCookie: true, - SameSite: http.SameSiteLaxMode, - CookieLifetime: 2000 * time.Millisecond, + EnableAutoCreate: true, + SessionIDLength: 16, + Cookie: CookieOptions{ + Name: "testcookiename", + Domain: "somedomain", + Path: "/abc/123", + IsSecure: true, + IsHTTPOnly: true, + SameSite: http.SameSiteLaxMode, + MaxAge: time.Hour * 1, + Expires: time.Now(), + }, } m := New(opts) - - assert := assert.New(t) - - // Default cookie path is set to root - assert.Equal(m.opts.DisableAutoSet, opts.DisableAutoSet) - assert.Equal(m.opts.CookieName, opts.CookieName) - assert.Equal(m.opts.CookieDomain, opts.CookieDomain) - assert.Equal(m.opts.CookiePath, opts.CookiePath) - assert.Equal(m.opts.IsSecureCookie, opts.IsSecureCookie) - assert.Equal(m.opts.SameSite, opts.SameSite) - assert.Equal(m.opts.IsHTTPOnlyCookie, opts.IsHTTPOnlyCookie) - assert.Equal(m.opts.CookieLifetime, opts.CookieLifetime) + assert.Equal(t, opts.EnableAutoCreate, m.opts.EnableAutoCreate) + assert.Equal(t, opts.SessionIDLength, m.opts.SessionIDLength) + assert.Equal(t, opts.Cookie.Name, m.opts.Cookie.Name) + assert.Equal(t, opts.Cookie.Domain, m.opts.Cookie.Domain) + assert.Equal(t, opts.Cookie.Path, m.opts.Cookie.Path) + assert.Equal(t, opts.Cookie.IsSecure, m.opts.Cookie.IsSecure) + assert.Equal(t, opts.Cookie.SameSite, m.opts.Cookie.SameSite) + assert.Equal(t, opts.Cookie.IsHTTPOnly, m.opts.Cookie.IsHTTPOnly) + assert.Equal(t, opts.Cookie.MaxAge, m.opts.Cookie.MaxAge) + assert.Equal(t, opts.Cookie.Expires, m.opts.Cookie.Expires) + + // Default opts. + m = New(Options{}) + assert.NotNil(t, m.generateID) + assert.NotNil(t, m.validateID) + + assert.Equal(t, false, m.opts.EnableAutoCreate) + assert.Equal(t, defaultSessIDLength, m.opts.SessionIDLength) + assert.Equal(t, defaultCookieName, m.opts.Cookie.Name) + assert.Equal(t, defaultCookiePath, m.opts.Cookie.Path) } func TestManagerUseStore(t *testing.T) { - assert := assert.New(t) - mockStr := &MockStore{} - assert.Implements((*Store)(nil), mockStr) - - m := New(Options{}) - m.UseStore(mockStr) - assert.Equal(m.store, mockStr) + s := newMockStore() + m := newMockManager(s) + assert.Equal(t, s, m.store) } -func TestManagerRegisterGetCookie(t *testing.T) { - assert := assert.New(t) - m := New(Options{}) - - testCookie := &http.Cookie{ +func TestManagerSetCookieHooks(t *testing.T) { + ck := &http.Cookie{ Name: "testcookie", } - cb := func(string, interface{}) (*http.Cookie, error) { - return testCookie, http.ErrNoCookie + get := func(string, interface{}) (*http.Cookie, error) { + return ck, http.ErrNoCookie } - - m.RegisterGetCookie(cb) - - expectCbRes, expectCbErr := cb("", nil) - actualCbRes, actualCbErr := m.getCookieCb("", nil) - - assert.Equal(expectCbRes, actualCbRes) - assert.Equal(expectCbErr, actualCbErr) -} - -func TestManagerRegisterSetCookie(t *testing.T) { - assert := assert.New(t) - m := New(Options{}) - - testCookie := &http.Cookie{ - Name: "testcookie", - } - - cb := func(*http.Cookie, interface{}) error { + set := func(*http.Cookie, interface{}) error { return http.ErrNoCookie } - m.RegisterSetCookie(cb) + m := New(Options{}) + m.SetCookieHooks(get, set) - expectCbErr := cb(testCookie, nil) - actualCbErr := m.setCookieCb(testCookie, nil) + expRes, expErr := get("", nil) + gotRes, gotErr := m.getCookieHook("", nil) + assert.Equal(t, expRes, gotRes) + assert.Equal(t, expErr, gotErr) - assert.Equal(expectCbErr, actualCbErr) + expErr = set(ck, nil) + gotErr = m.setCookieHook(ck, nil) + assert.Equal(t, expErr, gotErr) } func TestManagerAcquireFails(t *testing.T) { - assert := assert.New(t) m := New(Options{}) - _, err := m.Acquire(nil, nil, nil) - assert.Error(err, "session store is not set") + // Fail if store is not assigned. + _, err := m.Acquire(context.Background(), nil, nil) + assert.Equal(t, "session store not set", err.Error()) + // Fail if getCookie callback is not assigned. m.UseStore(&MockStore{}) - _, err = m.Acquire(nil, nil, nil) - assert.Error(err, "callback `GetCookie` not set") + _, err = m.Acquire(context.Background(), nil, nil) + assert.Equal(t, "`GetCookie` hook not set", err.Error()) - getCb := func(string, interface{}) (*http.Cookie, error) { - return nil, nil - } - m.RegisterGetCookie(getCb) - _, err = m.Acquire(nil, nil, nil) - assert.Error(err, "callback `SetCookie` not set") -} + // Assign getCookie, returns nil cookie to make sure it + // fails in create session with invalid session. + m.SetCookieHooks(func(string, interface{}) (*http.Cookie, error) { return nil, nil }, nil) -func TestManagerAcquireSucceeds(t *testing.T) { - m := New(Options{}) - m.UseStore(&MockStore{ - isValid: true, - }) + // Fail if setCookie callback is not assigned. + _, err = m.Acquire(context.Background(), nil, nil) + assert.Equal(t, "`SetCookie` hook not set", err.Error()) - getCb := func(string, interface{}) (*http.Cookie, error) { - return &http.Cookie{ - Name: "testcookie", - Value: "", - }, nil - } - m.RegisterGetCookie(getCb) + // Register setCookie callback. + m.SetCookieHooks(func(string, interface{}) (*http.Cookie, error) { return nil, nil }, + func(*http.Cookie, interface{}) error { return nil }) - setCb := func(*http.Cookie, interface{}) error { - return http.ErrNoCookie - } - m.RegisterSetCookie(setCb) + // By default EnableAutoCreate is disabled + // Check if it returns invalid session. + _, err = m.Acquire(context.Background(), nil, nil) + assert.ErrorIs(t, err, ErrInvalidSession) +} - _, err := m.Acquire(nil, nil, nil) - assert := assert.New(t) - assert.NoError(err) +func TestManagerAcquireAutocreate(t *testing.T) { + m := newMockManager(newMockStore()) + // Enable autocreate. + m.opts.EnableAutoCreate = true + m.SetCookieHooks(func(string, interface{}) (*http.Cookie, error) { return nil, ErrInvalidSession }, + func(*http.Cookie, interface{}) error { return nil }) + + // If cookie doesn't exist then should return a new one without error. + sess, err := m.Acquire(context.Background(), nil, nil) + assert.NoError(t, err) + assert.True(t, m.validateID(sess.id)) } func TestManagerAcquireFromContext(t *testing.T) { assert := assert.New(t) + m := newMockManager(newMockStore()) + + sess, err := m.Acquire(context.Background(), nil, nil) + sess.id = "updated" + assert.NoError(err) + + ctx := context.WithValue(context.Background(), ContextName, sess) + sessNext, err := m.Acquire(ctx, nil, nil) + assert.Equal(sess.id, sessNext.id) + assert.NoError(err) +} + +func TestDefaultGenerateID(t *testing.T) { m := New(Options{}) - m.UseStore(&MockStore{ - isValid: true, + id, err := m.generateID() + assert.NoError(t, err) + assert.Equal(t, defaultSessIDLength, len(id)) + + m = New(Options{ + SessionIDLength: 16, }) + id, err = m.generateID() + assert.NoError(t, err) + assert.Equal(t, 16, len(id)) +} - getCb := func(string, interface{}) (*http.Cookie, error) { - return &http.Cookie{ - Name: "testcookie", - Value: "", - }, nil - } - m.RegisterGetCookie(getCb) +func TestDefaultValidateID(t *testing.T) { + m := New(Options{}) + id, err := m.generateID() + assert.NoError(t, err) + assert.True(t, m.validateID(id)) + assert.False(t, m.validateID("xxxx")) + assert.False(t, m.validateID("11IHy6S2uBuKaNnTUszB218L898ikGY*")) +} - setCb := func(*http.Cookie, interface{}) error { - return http.ErrNoCookie +func TestSetSessionIDHooks(t *testing.T) { + var ( + m = New(Options{}) + genErr error = nil + genID = "xxx" + valOut = true + ) + gen := func() (string, error) { + return genID, genErr } - m.RegisterSetCookie(setCb) + validate := func(string) bool { + return valOut + } + m.SetSessionIDHooks(gen, validate) - sess, err := m.Acquire(nil, nil, nil) - assert.NoError(err) - sess.cookie.Value = "updated" + id, err := m.generateID() + eID, eErr := gen() + assert.Equal(t, eID, id) + assert.Equal(t, eErr, err) - sessNew, err := m.Acquire(nil, nil, nil) - assert.NoError(err) - assert.NotEqual(sessNew.cookie.Value, sess.cookie.Value) + genErr = fmt.Errorf("custom error") + _, err = m.generateID() + assert.ErrorIs(t, genErr, err) - ctx := context.Background() - ctx = context.WithValue(ctx, ContextName, sess) - sessNext, err := m.Acquire(nil, nil, ctx) - assert.Equal(sessNext.cookie.Value, sess.cookie.Value) + assert.True(t, m.validateID(genID)) + valOut = false + assert.False(t, m.validateID(genID)) } diff --git a/session.go b/session.go index e34d742..d6decfb 100644 --- a/session.go +++ b/session.go @@ -3,30 +3,27 @@ package simplesessions import ( "errors" "net/http" + "sync" "time" ) -// Session is utility for get, set or clear session. +// Session represents a session object used for retrieving/setting session data and cookies. type Session struct { - // Map to store session data which can be loaded using `Load` method. - // Get session method check if the field is available here before getting from store directly. - values map[string]interface{} + // Map to store session data, loaded using `CacheAll` method. + // All `Get` methods tries to retrive cached value before fetching from the store. + // If its nil then cache is not set and `Get` methods directly fetch from the store. + cache map[string]interface{} + cacheMux sync.RWMutex // Session manager. manager *Manager - // Current http cookie. This is passed down to `SetCookie` callback. - cookie *http.Cookie + // Session ID. + id string - // HTTP reader and writer interfaces which are passed on to - // `GetCookie`` and `SetCookie`` callback respectively. + // HTTP reader and writer interfaces which are passed on to `GetCookie`` and `SetCookie`` callbacks. reader interface{} writer interface{} - - // Track if session is set in store or not - // used to throw and error is autoSet is not enabled and user - // explicitly didn't create new session in store. - isSet bool } var ( @@ -35,320 +32,289 @@ var ( // Store code = 1 ErrInvalidSession = errors.New("simplesession: invalid session") - // ErrFieldNotFound is raised when given key is not found in store + // ErrNil is raised when returned value is nil. // Store code = 2 - ErrFieldNotFound = errors.New("simplesession: session field not found in store") + ErrNil = errors.New("simplesession: nil returned") // ErrAssertType is raised when type assertion fails // Store code = 3 ErrAssertType = errors.New("simplesession: invalid type assertion") - - // ErrNil is raised when returned value is nil. - // Store code = 4 - ErrNil = errors.New("simplesession: nil returned") ) type errCode interface { Code() int } -// NewSession creates a new session. Reads cookie info from `GetCookie“ callback -// and validate the session with current store. If cookie not set then it creates -// new session and calls `SetCookie“ callback. If `DisableAutoSet` is set then it -// skips new session creation and should be manually done using `Create` method. -// If a cookie is found but its invalid in store then `ErrInvalidSession` error is returned. -func NewSession(m *Manager, r, w interface{}) (*Session, error) { - var ( - err error - sess = &Session{ - manager: m, - reader: r, - writer: w, - values: make(map[string]interface{}), - } - ) +// WriteCookie writes the cookie for the given session ID. +// Uses all the cookie options set in Manager. +func (s *Session) WriteCookie(id string) error { + ck := &http.Cookie{ + Value: id, + Name: s.manager.opts.Cookie.Name, + Domain: s.manager.opts.Cookie.Domain, + Path: s.manager.opts.Cookie.Path, + Secure: s.manager.opts.Cookie.IsSecure, + HttpOnly: s.manager.opts.Cookie.IsHTTPOnly, + SameSite: s.manager.opts.Cookie.SameSite, + Expires: s.manager.opts.Cookie.Expires, + MaxAge: int(s.manager.opts.Cookie.MaxAge.Seconds()), + } - // Get existing http session cookie - sess.cookie, err = m.getCookieCb(m.opts.CookieName, r) + // Call `SetCookie` callback to write cookie to response + return s.manager.setCookieHook(ck, s.writer) +} - // Create new session - if err == http.ErrNoCookie { - // Skip creating new cookie in store. User has to manually create before doing Get or Set. - if m.opts.DisableAutoSet { - return sess, nil - } +// ClearCookie sets the cookie's expiry to one day prior to clear it. +func (s *Session) ClearCookie() error { + ck := &http.Cookie{ + Name: s.manager.opts.Cookie.Name, + Value: "", + // Set expiry to previous date to clear it from browser + Expires: time.Now().AddDate(0, 0, -1), + } - // Create new cookie in store and write to front - // Store also calls `WriteCookie`` to write to http interface - cv, err := m.store.Create() - if err != nil { - return nil, errAs(err) - } + // Call `SetCookie` callback to write cookie to response + return s.manager.setCookieHook(ck, s.writer) +} - // Write cookie - if err := sess.WriteCookie(cv); err != nil { - return nil, err - } +// ID returns the acquired session ID. If cookie is not set then empty string is returned. +func (s *Session) ID() string { + return s.id +} + +// getCacheAll returns a copy of cached map. +func (s *Session) getCacheAll() map[string]interface{} { + s.cacheMux.RLock() + defer s.cacheMux.RUnlock() - // Set isSet flag - sess.isSet = true - } else if err != nil { - return nil, err + if s.cache == nil { + return nil } - // Set isSet flag - sess.isSet = true + out := map[string]interface{}{} + for k, v := range s.cache { + out[k] = v + } - return sess, nil + return out } -// WriteCookie updates the cookie and calls `SetCookie` callback. -// This method can also be used by store to update cookie whenever the cookie value changes. -func (s *Session) WriteCookie(cv string) error { - s.cookie = &http.Cookie{ - Value: cv, - Name: s.manager.opts.CookieName, - Domain: s.manager.opts.CookieDomain, - Path: s.manager.opts.CookiePath, - Secure: s.manager.opts.IsSecureCookie, - HttpOnly: s.manager.opts.IsHTTPOnlyCookie, - SameSite: s.manager.opts.SameSite, +// getCacheAll returns a map of values for the given list of keys. +// If key doesn't exist then Nil is returned. +func (s *Session) getCache(key ...string) map[string]interface{} { + s.cacheMux.RLock() + defer s.cacheMux.RUnlock() + + if s.cache == nil { + return nil } - // Set cookie expiry - if s.manager.opts.CookieLifetime != 0 { - s.cookie.Expires = time.Now().Add(s.manager.opts.CookieLifetime) + out := map[string]interface{}{} + for _, k := range key { + v, ok := s.cache[k] + if ok { + out[k] = v + } else { + out[k] = nil + } } - // Call `SetCookie` callback to write cookie to response - return s.manager.setCookieCb(s.cookie, s.writer) + return out } -// clearCookie sets expiry of the cookie to one day before to clear it. -func (s *Session) clearCookie() error { - s.cookie = &http.Cookie{ - Name: s.manager.opts.CookieName, - Value: "", - // Set expiry to previous date to clear it from browser - Expires: time.Now().AddDate(0, 0, -1), +// setCache sets a cache for given kv pairs. +func (s *Session) setCache(data map[string]interface{}) { + s.cacheMux.Lock() + defer s.cacheMux.Unlock() + + // If cacheAll() is not called the don't maintain cache. + if s.cache == nil { + return } - // Call `SetCookie` callback to write cookie to response - return s.manager.setCookieCb(s.cookie, s.writer) + for k, v := range data { + s.cache[k] = v + } } -// Create a new session. This is implicit when option `DisableAutoSet` is false -// else session has to be manually created before setting or getting values. -func (s *Session) Create() error { - // Create new cookie in store and write to front. - cv, err := s.manager.store.Create() - if err != nil { - return errAs(err) - } +// deleteCache sets a cache for given kv pairs. +func (s *Session) deleteCache(key ...string) { + s.cacheMux.Lock() + defer s.cacheMux.Unlock() - // Write cookie - if err := s.WriteCookie(cv); err != nil { - return err + // If cacheAll() is not called the don't maintain cache. + if s.cache == nil { + return } - // Set isSet flag - s.isSet = true - - return nil + for _, k := range key { + delete(s.cache, k) + } } -// ID returns the acquired session ID. If cookie is not set then empty string is returned. -func (s *Session) ID() string { - if s.cookie != nil { - return s.cookie.Value +// CacheAll loads session values into memory for quick access. +// Ideal for centralized session fetching, e.g., in middleware. +// Subsequent Get/GetMulti calls return cached values, avoiding store access. +// Use ResetCache() to ensure GetAll/Get/GetMulti fetches from the store. +func (s *Session) CacheAll() error { + all, err := s.manager.store.GetAll(s.id) + if err != nil { + return err } - return "" -} -// LoadValues loads the session values in memory. -// Get session field tries to get value from memory before hitting store. -func (s *Session) LoadValues() error { - var err error - s.values, err = s.GetAll() - return err + s.cacheMux.Lock() + defer s.cacheMux.Unlock() + s.cache = map[string]interface{}{} + for k, v := range all { + s.cache[k] = v + } + + return nil } -// ResetValues reset the loaded values using `LoadValues` method.ResetValues -// Subsequent Get, GetAll and GetMulti -func (s *Session) ResetValues() { - s.values = make(map[string]interface{}) +// ResetCache clears loaded values, ensuring subsequent Get, GetAll, and GetMulti calls fetch from the store. +func (s *Session) ResetCache() { + s.cacheMux.Lock() + defer s.cacheMux.Unlock() + s.cache = nil } -// GetAll gets all the fields in the session. +// GetAll gets all the fields for the given session id. func (s *Session) GetAll() (map[string]interface{}, error) { - // Check if session is set before accessing it - if !s.isSet { - return nil, ErrInvalidSession - } - - // Load value from map if its already loaded - if len(s.values) > 0 { - return s.values, nil + // Try to get the values from cache. + c := s.getCacheAll() + if c != nil { + return c, nil } - out, err := s.manager.store.GetAll(s.cookie.Value) + // Get the values from store. + out, err := s.manager.store.GetAll(s.id) return out, errAs(err) } -// GetMulti gets a map of values for multiple session keys. -func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) { - // Check if session is set before accessing it - if !s.isSet { - return nil, ErrInvalidSession +// GetMulti retrieves values for multiple session fields. +// If a field is not found in the store then its returned as nil. +func (s *Session) GetMulti(key ...string) (map[string]interface{}, error) { + // Try to get the values from cache. + c := s.getCache(key...) + if c != nil { + return c, nil } - // Load values from map if its already loaded - if len(s.values) > 0 { - vals := make(map[string]interface{}) - for _, k := range keys { - if v, ok := s.values[k]; ok { - vals[k] = v - } - } - - return vals, nil - } - - out, err := s.manager.store.GetMulti(s.cookie.Value, keys...) + out, err := s.manager.store.GetMulti(s.id, key...) return out, errAs(err) } -// Get gets a value for given key in session. -// If session is already loaded using `Load` then returns values from -// existing map instead of getting it from store. +// Get retrieves a value for the given key in the session. +// If the session is already loaded, it returns the value from the existing map. +// Otherwise, it fetches the value from the store. func (s *Session) Get(key string) (interface{}, error) { - // Check if session is set before accessing it - if !s.isSet { - return nil, ErrInvalidSession + // Try to get the values from cache. + // If cache is set then get only from cache. + c := s.getCache(key) + if c != nil { + return c[key], nil } - // Load value from map if its already loaded - if len(s.values) > 0 { - if val, ok := s.values[key]; ok { - return val, nil - } - } - - // Get from backend if not found in previous step - out, err := s.manager.store.Get(s.cookie.Value, key) + // Fetch from store if not found in the map. + out, err := s.manager.store.Get(s.id, key) return out, errAs(err) } -// Set sets a value for given key in session. Its up to store to commit -// all previously set values at once or store it on each set. +// Set assigns a value to the given key in the session. func (s *Session) Set(key string, val interface{}) error { - // Check if session is set before accessing it - if !s.isSet { - return ErrInvalidSession + err := s.manager.store.Set(s.id, key, val) + if err == nil { + s.setCache(map[string]interface{}{ + key: val, + }) } - - err := s.manager.store.Set(s.cookie.Value, key, val) return errAs(err) } -// SetMulti sets all values in the session. -// Its up to store to commit all previously -// set values at once or store it on each set. -func (s *Session) SetMulti(values map[string]interface{}) error { - // Check if session is set before accessing it - if !s.isSet { - return ErrInvalidSession - } - - for k, v := range values { - if err := s.manager.store.Set(s.cookie.Value, k, v); err != nil { - return errAs(err) - } +// SetMulti assigns multiple values to the session. +func (s *Session) SetMulti(data map[string]interface{}) error { + err := s.manager.store.SetMulti(s.id, data) + if err == nil { + s.setCache(data) } - - return nil + return errAs(err) } -// Commit commits all set to store. Its up to store to commit -// all previously set values at once or store it on each set. -func (s *Session) Commit() error { - // Check if session is set before accessing it - if !s.isSet { - return ErrInvalidSession - } - - if err := s.manager.store.Commit(s.cookie.Value); err != nil { - return errAs(err) +// Delete deletes a given list of fields from the session. +func (s *Session) Delete(key ...string) error { + err := s.manager.store.Delete(s.id, key...) + if err == nil { + s.deleteCache(key...) } - - return nil + return errAs(err) } -// Delete deletes a field from session. -func (s *Session) Delete(key string) error { - // Check if session is set before accessing it - if !s.isSet { - return ErrInvalidSession - } - - if err := s.manager.store.Delete(s.cookie.Value, key); err != nil { +// Clear empties the data for the given session id but doesn't clear the cookie. +// Use `Destroy()` to delete entire session from the store and clear the cookie. +func (s *Session) Clear() error { + err := s.manager.store.Clear(s.id) + if err != nil { return errAs(err) } - + s.ResetCache() return nil } -// Clear clears session data from store and clears the cookie -func (s *Session) Clear() error { - // Check if session is set before accessing it - if !s.isSet { - return ErrInvalidSession - } - - if err := s.manager.store.Clear(s.cookie.Value); err != nil { +// Destroy deletes the session from backend and clears the cookie. +func (s *Session) Destroy() error { + err := s.manager.store.Destroy(s.id) + if err != nil { return errAs(err) } - - return s.clearCookie() + s.ResetCache() + return s.ClearCookie() } -// Int is a helper to get values as integer +// Int is a helper to get values as integer. +// If the value is Nil, ErrNil is returned, which means key doesn't exist. func (s *Session) Int(r interface{}, err error) (int, error) { out, err := s.manager.store.Int(r, err) return out, errAs(err) } -// Int64 is a helper to get values as Int64 +// Int64 is a helper to get values as Int64. +// If the value is Nil, ErrNil is returned, which means key doesn't exist. func (s *Session) Int64(r interface{}, err error) (int64, error) { out, err := s.manager.store.Int64(r, err) return out, errAs(err) } -// UInt64 is a helper to get values as UInt64 +// UInt64 is a helper to get values as UInt64. +// If the value is Nil, ErrNil is returned, which means key doesn't exist. func (s *Session) UInt64(r interface{}, err error) (uint64, error) { out, err := s.manager.store.UInt64(r, err) return out, errAs(err) } -// Float64 is a helper to get values as Float64 +// Float64 is a helper to get values as Float64. +// If the value is Nil, ErrNil is returned, which means key doesn't exist. func (s *Session) Float64(r interface{}, err error) (float64, error) { out, err := s.manager.store.Float64(r, err) return out, errAs(err) } -// String is a helper to get values as String +// String is a helper to get values as String. +// If the value is Nil, ErrNil is returned, which means key doesn't exist. func (s *Session) String(r interface{}, err error) (string, error) { out, err := s.manager.store.String(r, err) return out, errAs(err) } -// Bytes is a helper to get values as Bytes +// Bytes is a helper to get values as Bytes. +// If the value is Nil, ErrNil is returned, which means key doesn't exist. func (s *Session) Bytes(r interface{}, err error) ([]byte, error) { out, err := s.manager.store.Bytes(r, err) return out, errAs(err) } -// Bool is a helper to get values as Bool +// Bool is a helper to get values as Bool. +// If the value is Nil, ErrNil is returned, which means key doesn't exist. func (s *Session) Bool(r interface{}, err error) (bool, error) { out, err := s.manager.store.Bool(r, err) return out, errAs(err) @@ -370,11 +336,9 @@ func errAs(err error) error { case 1: return ErrInvalidSession case 2: - return ErrFieldNotFound + return ErrNil case 3: return ErrAssertType - case 4: - return ErrNil } return err diff --git a/session_test.go b/session_test.go index 6a130f1..119d387 100644 --- a/session_test.go +++ b/session_test.go @@ -2,6 +2,7 @@ package simplesessions import ( "errors" + "fmt" "net/http" "testing" "time" @@ -9,37 +10,36 @@ import ( "github.com/stretchr/testify/assert" ) -var ( - testCookieName = "sometestcookie" - testCookieValue = "sometestcookievalue" -) - -func newMockStore() *MockStore { - return &MockStore{} +type Err struct { + code int + msg string } -func newMockManager(store *MockStore) *Manager { - mockManager := New(Options{}) - mockManager.UseStore(store) - mockManager.RegisterGetCookie(getCookieCb) - mockManager.RegisterSetCookie(setCookieCb) - - return mockManager +func (e *Err) Error() string { + return e.msg } -func getCookieCb(name string, r interface{}) (*http.Cookie, error) { - return &http.Cookie{ - Name: name, - Value: testCookieValue, - }, nil +func (e *Err) Code() int { + return e.code } -func setCookieCb(*http.Cookie, interface{}) error { - return nil +func TestErrorTypes(t *testing.T) { + var ( + // Error codes for store errors. This should match the codes + // defined in the /simplesessions package exactly. + errInvalidSession = &Err{code: 1, msg: "invalid session"} + errNil = &Err{code: 2, msg: "nil returned"} + errAssertType = &Err{code: 3, msg: "assertion failed"} + errCustom = &Err{msg: "custom error"} + ) + + assert.Equal(t, errAs(errInvalidSession), ErrInvalidSession) + assert.Equal(t, errAs(errAssertType), ErrAssertType) + assert.Equal(t, errAs(errNil), ErrNil) + assert.Equal(t, errAs(errCustom), errCustom) } -func TestSessionHelpers(t *testing.T) { - assert := assert.New(t) +func TestHelpers(t *testing.T) { sess := Session{ manager: newMockManager(newMockStore()), } @@ -47,653 +47,510 @@ func TestSessionHelpers(t *testing.T) { // Int var inp1 = 100 v1, err := sess.Int(inp1, errors.New("test error")) - assert.Equal(v1, inp1) - assert.Error(err, "test error") + assert.Equal(t, inp1, v1) + assert.Equal(t, "test error", err.Error()) // Int64 var inp2 int64 = 100 v2, err := sess.Int64(inp2, errors.New("test error")) - assert.Equal(v2, inp2) - assert.Error(err, "test error") + assert.Equal(t, inp2, v2) + assert.Equal(t, "test error", err.Error()) var inp3 uint64 = 100 v3, err := sess.UInt64(inp3, errors.New("test error")) - assert.Equal(v3, inp3) - assert.Error(err, "test error") + assert.Equal(t, inp3, v3) + assert.Equal(t, "test error", err.Error()) var inp4 float64 = 100 v4, err := sess.Float64(inp4, errors.New("test error")) - assert.Equal(v4, inp4) - assert.Error(err, "test error") + assert.Equal(t, inp4, v4) + assert.Equal(t, "test error", err.Error()) var inp5 = "abc123" v5, err := sess.String(inp5, errors.New("test error")) - assert.Equal(v5, inp5) - assert.Error(err, "test error") + assert.Equal(t, inp5, v5) + assert.Equal(t, "test error", err.Error()) var inp6 = true v6, err := sess.Bool(inp6, errors.New("test error")) - assert.Equal(v6, inp6) - assert.Error(err, "test error") + assert.Equal(t, inp6, v6) + assert.Equal(t, "test error", err.Error()) var inp7 = []byte{} v7, err := sess.Bytes(inp7, errors.New("test error")) - assert.Equal(v7, inp7) - assert.Error(err, "test error") + assert.Equal(t, inp7, v7) + assert.Equal(t, "test error", err.Error()) } -func TestSessionNewSession(t *testing.T) { +func TestNewSession(t *testing.T) { reader := "some reader" writer := "some writer" - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) + mgr := newMockManager(newMockStore()) - assert := assert.New(t) - sess, err := NewSession(mockManager, reader, writer) - assert.NoError(err) - assert.Equal(sess.manager, mockManager) - assert.Equal(sess.reader, reader) - assert.Equal(sess.writer, writer) - assert.NotNil(sess.values) - assert.NotNil(sess.cookie) - assert.Equal(sess.cookie.Name, defaultCookieName) - assert.Equal(sess.cookie.Value, testCookieValue) - assert.True(sess.isSet) + sess, err := mgr.NewSession(reader, writer) + assert.NoError(t, err) + assert.Equal(t, mgr, sess.manager) + assert.Equal(t, reader, sess.reader) + assert.Equal(t, writer, sess.writer) + assert.Nil(t, sess.cache) + assert.Equal(t, sess.id, sess.ID()) } -func TestSessionNewSessionErrorStoreCreate(t *testing.T) { +func TestNewSessionErrors(t *testing.T) { assert := assert.New(t) - mockStore := newMockStore() - mockStore.isValid = true - - testError := errors.New("this is test error") - newCookieVal := "somerandomid" - mockStore.val = newCookieVal - mockStore.err = testError - mockManager := newMockManager(mockStore) - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - sess, err := NewSession(mockManager, nil, nil) - assert.Error(err, testError.Error()) + mgr := New(Options{}) + sess, err := mgr.NewSession(nil, nil) + assert.Equal("session store not set", err.Error()) assert.Nil(sess) -} - -func TestSessionNewSessionErrorWriteCookie(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockStore.isValid = true - - testError := errors.New("this is test error") - newCookieVal := "somerandomid" - mockStore.val = newCookieVal - mockManager := newMockManager(mockStore) - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - mockManager.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - return testError - }) - sess, err := NewSession(mockManager, nil, nil) - assert.Error(err, testError.Error()) + mgr = New(Options{}) + mgr.UseStore(&MockStore{}) + sess, err = mgr.NewSession(nil, nil) + assert.Equal("`SetCookie` hook not set", err.Error()) assert.Nil(sess) -} -func TestSessionNewSessionInvalidGetCookie(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - testError := errors.New("custom error") - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, testError - }) - - sess, err := NewSession(mockManager, nil, nil) - assert.Error(err, testError.Error()) + // Store error. + tErr := errors.New("store error") + str := newMockStore() + str.err = tErr + mgr = newMockManager(str) + sess, err = mgr.NewSession(nil, nil) + assert.ErrorIs(tErr, err) assert.Nil(sess) -} - -func TestSessionNewSessionCreateNewCookie(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - - newCookieVal := "somerandomid" - mockStore.val = newCookieVal - mockStore.isValid = true - mockManager := newMockManager(mockStore) - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - assert.True(sess.isSet) - assert.Equal(sess.cookie.Value, newCookieVal) -} -func TestSessionNewSessionWithDisableAutoSet(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() + // Cookie write error. + str.err = nil + wErr := errors.New("write cookie error") + mgr.SetCookieHooks(nil, func(*http.Cookie, interface{}) error { return wErr }) - mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) + sess, err = mgr.NewSession(nil, nil) + assert.ErrorIs(wErr, err) + assert.Nil(sess) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - assert.False(sess.isSet) + genErr := fmt.Errorf("generate error") + gen := func() (string, error) { return "xxx", genErr } + validate := func(string) bool { return false } + mgr.SetSessionIDHooks(gen, validate) + sess, err = mgr.NewSession(nil, nil) + assert.ErrorIs(genErr, err) + assert.Nil(sess) } -func TestSessionNewSessionGetCookieCb(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - - // Calls write cookie callback if cookie is not set already - newCookieVal := "somerandomid" - mockStore.val = newCookieVal - mockStore.isValid = true - mockManager := newMockManager(mockStore) - - var receivedName string - var receivedReader interface{} - var isCallbackTriggered bool - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - isCallbackTriggered = true - receivedName = name - receivedReader = r - return nil, http.ErrNoCookie - }) - - var reader = "this is reader interface" - sess, err := NewSession(mockManager, reader, nil) - assert.NoError(err) - assert.True(sess.isSet) - assert.True(isCallbackTriggered) - assert.Equal(receivedName, mockManager.opts.CookieName) - assert.Equal(receivedReader, reader) +func TestNewSessionCreateNewCookie(t *testing.T) { + mgr := newMockManager(newMockStore()) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + assert.True(t, mgr.validateID(sess.id)) } -func TestSessionNewSessionSetCookieCb(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - - // Calls write cookie callback if cookie is not set already - newCookieVal := "somerandomid" - mockStore.val = newCookieVal - mockStore.isValid = true - mockManager := newMockManager(mockStore) - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) +func TestNewSessionSetCookieCb(t *testing.T) { + var ( + mgr = newMockManager(newMockStore()) + receCk *http.Cookie + receWr interface{} + isCb bool + ) - var receivedCookie *http.Cookie - var receivedWriter interface{} - var isCallbackTriggered bool - mockManager.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - receivedCookie = cookie - receivedWriter = w - isCallbackTriggered = true + mgr.SetCookieHooks(nil, func(ck *http.Cookie, w interface{}) error { + receCk = ck + receWr = w + isCb = true return nil }) var writer = "this is writer interface" - sess, err := NewSession(mockManager, nil, writer) - assert.NoError(err) - assert.True(sess.isSet) - assert.True(isCallbackTriggered) - assert.Equal(receivedCookie.Value, newCookieVal) - assert.Equal(receivedWriter, writer) -} - -func TestSessionWriteCookie(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts = &Options{ - CookieName: "somename", - CookieDomain: "abc.xyz", - CookiePath: "/abc/xyz", - CookieLifetime: time.Second * 1000, - IsHTTPOnlyCookie: true, - IsSecureCookie: true, - DisableAutoSet: true, - SameSite: http.SameSiteDefaultMode, + sess, err := mgr.NewSession(nil, writer) + assert.NoError(t, err) + + assert.True(t, isCb) + assert.Equal(t, sess.id, receCk.Value) + assert.Equal(t, writer, receWr) +} + +func TestWriteCookie(t *testing.T) { + mgr := newMockManager(newMockStore()) + mgr.opts = &Options{ + EnableAutoCreate: false, + Cookie: CookieOptions{ + Name: "somename", + Domain: "abc.xyz", + Path: "/abc/xyz", + IsHTTPOnly: true, + IsSecure: true, + SameSite: http.SameSiteDefaultMode, + MaxAge: time.Hour, + Expires: time.Now(), + }, } - mockStore.isValid = true - - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - sess.WriteCookie("testvalue") - assert.Equal(sess.cookie.Name, mockManager.opts.CookieName) - assert.Equal(sess.cookie.Value, "testvalue") - assert.Equal(sess.cookie.Domain, mockManager.opts.CookieDomain) - assert.Equal(sess.cookie.Path, mockManager.opts.CookiePath) - assert.Equal(sess.cookie.Secure, mockManager.opts.IsSecureCookie) - assert.Equal(sess.cookie.SameSite, mockManager.opts.SameSite) - assert.Equal(sess.cookie.HttpOnly, mockManager.opts.IsHTTPOnlyCookie) - - // Ignore seconds - expiry := time.Now().Add(mockManager.opts.CookieLifetime) - assert.Equal(sess.cookie.Expires.Format("2006-01-02 15:04:05"), expiry.Format("2006-01-02 15:04:05")) - assert.WithinDuration(expiry, sess.cookie.Expires, time.Millisecond*1000) -} -func TestSessionClearCookie(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockStore.isValid = true - - var receivedCookie *http.Cookie - var isCallbackTriggered bool - mockManager.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - receivedCookie = cookie - isCallbackTriggered = true + var receCk *http.Cookie + mgr.SetCookieHooks(nil, func(ck *http.Cookie, w interface{}) error { + receCk = ck return nil }) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + assert.NoError(t, sess.WriteCookie("testvalue")) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - err = sess.clearCookie() - assert.NoError(err) - - assert.True(isCallbackTriggered) - assert.Equal(receivedCookie.Value, "") - assert.True(receivedCookie.Expires.UnixNano() < time.Now().UnixNano()) + assert.Equal(t, mgr.opts.Cookie.Name, receCk.Name) + assert.Equal(t, mgr.opts.Cookie.Domain, receCk.Domain) + assert.Equal(t, mgr.opts.Cookie.Path, receCk.Path) + assert.Equal(t, mgr.opts.Cookie.IsSecure, receCk.Secure) + assert.Equal(t, mgr.opts.Cookie.SameSite, receCk.SameSite) + assert.Equal(t, mgr.opts.Cookie.IsHTTPOnly, receCk.HttpOnly) + assert.Equal(t, int(mgr.opts.Cookie.MaxAge.Seconds()), receCk.MaxAge) + assert.Equal(t, mgr.opts.Cookie.Expires, receCk.Expires) } -func TestSessionCreate(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = "test" - mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - var isCallbackTriggered bool - mockManager.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - isCallbackTriggered = true +func TestClearCookie(t *testing.T) { + var ( + mgr = newMockManager(newMockStore()) + receCk *http.Cookie + isCb bool + ) + mgr.SetCookieHooks(nil, func(ck *http.Cookie, w interface{}) error { + receCk = ck + isCb = true return nil }) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - assert.False(sess.isSet) - assert.False(isCallbackTriggered) - - err = sess.Create() - assert.NoError(err) - assert.True(isCallbackTriggered) - assert.True(sess.isSet) -} - -func TestSessionLoadValues(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = 100 - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - err = sess.LoadValues() - assert.NoError(err) - assert.Contains(sess.values, "val") - assert.Equal(sess.values["val"], 100) -} - -func TestSessionResetValues(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = 100 - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) - err = sess.LoadValues() - assert.NoError(err) - assert.Contains(sess.values, "val") - assert.Equal(sess.values["val"], 100) + err = sess.ClearCookie() + assert.NoError(t, err) - sess.ResetValues() - assert.Equal(len(sess.values), 0) + assert.True(t, isCb) + assert.Equal(t, "", receCk.Value) + assert.True(t, receCk.Expires.UnixNano() < time.Now().UnixNano()) } -func TestSessionGetAllFromStore(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = 100 - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - vals, err := sess.GetAll() - assert.NoError(err) - assert.Contains(vals, "val") - assert.Equal(vals["val"], 100) -} - -func TestSessionGetAllLoadedValues(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - setVals := make(map[string]interface{}) - setVals["sample"] = "someval" - sess.values = setVals - - vals, err := sess.GetAll() - assert.NoError(err) - assert.Contains(vals, "sample") - assert.Equal(vals["sample"], "someval") -} - -func TestSessionGetAllInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - vals, err := sess.GetAll() - assert.Error(err, ErrInvalidSession.Error()) - assert.Nil(vals) -} - -func TestSessionGetMultiFromStore(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = 100 - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - vals, err := sess.GetMulti("val") - assert.NoError(err) - assert.Contains(vals, "val") - assert.Equal(vals["val"], 100) -} - -func TestSessionGetMultiLoadedValues(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - setVals := make(map[string]interface{}) - setVals["key1"] = "someval" - setVals["key2"] = "someval" - sess.values = setVals - - vals, err := sess.GetMulti("key1") - assert.NoError(err) - assert.Contains(vals, "key1") - assert.Equal(vals["key1"], "someval") - assert.NotContains(vals, "key2") -} - -func TestSessionGetMultiInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - vals, err := sess.GetMulti("val") - assert.Error(err, ErrInvalidSession.Error()) - assert.Nil(vals) -} - -func TestSessionGetFromStore(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockStore.val = 100 - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - val, err := sess.Get("val") - assert.NoError(err) - assert.Equal(val, 100) -} - -func TestSessionGetLoadedValues(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - setVals := make(map[string]interface{}) - setVals["key1"] = "someval1" - setVals["key2"] = "someval2" - sess.values = setVals - - val, err := sess.Get("key1") - assert.NoError(err) - assert.Equal(val, "someval1") -} - -func TestSessionGetInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - vals, err := sess.Get("val") - assert.Error(err, ErrInvalidSession.Error()) - assert.Nil(vals) -} - -func TestSessionSet(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - err = sess.Set("key", 100) - assert.NoError(err) - assert.Equal(mockStore.val, 100) -} - -func TestSessionSetInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - err = sess.Set("key", 100) - assert.Error(err, ErrInvalidSession.Error()) -} - -func TestSessionCommit(t *testing.T) { - mockStore := newMockStore() - mockStore.isValid = true - mockManager := newMockManager(mockStore) +func TestCacheAll(t *testing.T) { + str := newMockStore() + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + } + mgr := newMockManager(str) - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - err = sess.Set("key", 100) - assert.NoError(err) - assert.NoError(err) - assert.False(mockStore.isCommited) - err = sess.Commit() - assert.NoError(err) - assert.True(mockStore.isCommited) -} + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) -func TestSessionCommitInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) + // Test error. + str.err = errors.New("store error") + err = sess.CacheAll() + assert.ErrorIs(t, str.err, err) + assert.Nil(t, sess.cache) - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - err = sess.Commit() - assert.Error(err, ErrInvalidSession.Error()) + // Test without error. + str.err = nil + err = sess.CacheAll() + assert.NoError(t, err) + assert.Equal(t, str.data, sess.cache) } -func TestSessionDelete(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockStore.isValid = true - mockStore.val = 100 - - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - assert.Equal(mockStore.val, 100) - - err = sess.Delete("somekey") - assert.NoError(err) - assert.Nil(mockStore.val) - - testError := errors.New("this is test error") - mockStore.err = testError - err = sess.Delete("somekey") - assert.Error(err, testError.Error()) -} +func TestResetCache(t *testing.T) { + str := newMockStore() + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + } + mgr := newMockManager(str) + sess, _ := mgr.NewSession(nil, nil) + sess.CacheAll() + assert.Equal(t, str.data, sess.cache) + + sess.ResetCache() + assert.Nil(t, sess.cache) +} + +func TestGetStore(t *testing.T) { + str := newMockStore() + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + "key3": 3, + } -func TestSessionClear(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockStore.isValid = true - mockStore.val = 100 - - var isCallbackTriggered bool - mockManager.RegisterSetCookie(func(cookie *http.Cookie, w interface{}) error { - isCallbackTriggered = true - return nil - }) + // GetAll. + v1, err := sess.GetAll() + assert.NoError(t, err) + assert.Equal(t, str.data, v1) + + // Get Multi. + v2, err := sess.GetMulti("key1", "key2") + assert.NoError(t, err) + assert.Contains(t, v2, "key1") + assert.Equal(t, str.data["key1"], v2["key1"]) + assert.Contains(t, v2, "key2") + assert.Equal(t, str.data["key2"], v2["key2"]) + assert.NotContains(t, v2, "key3") + + // Get. + v3, err := sess.Get("key1") + assert.NoError(t, err) + assert.Contains(t, str.data, "key1") + assert.Equal(t, str.data["key1"], v3) +} + +func TestGetCached(t *testing.T) { + str := newMockStore() + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + + sess.cache = map[string]interface{}{ + "key1": 1, + "key2": 2, + "key3": 3, + } - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - assert.Equal(mockStore.val, 100) + // GetAll. + v1, err := sess.GetAll() + assert.NoError(t, err) + assert.Equal(t, sess.cache, v1) + + // GetMulti. + v2, err := sess.GetMulti("key1", "key2") + assert.NoError(t, err) + assert.Contains(t, v2, "key1") + assert.Equal(t, sess.cache["key1"], v2["key1"]) + assert.Contains(t, v2, "key2") + assert.Equal(t, sess.cache["key2"], v2["key2"]) + assert.NotContains(t, v2, "key3") + + // Get. + v3, err := sess.Get("key1") + assert.NoError(t, err) + assert.Contains(t, sess.cache, "key1") + assert.Equal(t, sess.cache["key1"], v3) + + // Get unknowm field. + v3, err = sess.Get("key99") + assert.NoError(t, err) + assert.Nil(t, v3) + + // GetMulti unknown fields + v4, err := sess.GetMulti("key1", "key2", "key99", "key100") + assert.NoError(t, err) + assert.Contains(t, v4, "key1") + assert.Equal(t, sess.cache["key1"], v4["key1"]) + assert.Contains(t, v4, "key99") + assert.Contains(t, v4, "key100") + + v5, ok := v4["key99"] + assert.True(t, ok) + assert.Nil(t, v5) + + v5, ok = v4["key100"] + assert.True(t, ok) + assert.Nil(t, v5) +} + +func TestSet(t *testing.T) { + str := newMockStore() + str.data = map[string]interface{}{} + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + + err = sess.Set("key1", 1) + assert.NoError(t, err) + + // Check if its set on data after commit. + assert.Contains(t, str.data, "key1") + assert.Equal(t, 1, str.data["key1"]) + assert.Nil(t, sess.cache) + + // Cache and set. + err = sess.CacheAll() + assert.NoError(t, err) + err = sess.Set("key1", 1) + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + assert.Equal(t, sess.cache, str.data) +} + +func TestSetMulti(t *testing.T) { + str := newMockStore() + str.data = map[string]interface{}{} + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + + data := map[string]interface{}{ + "key1": 1, + "key2": 2, + "key3": 3, + } + err = sess.SetMulti(data) + assert.NoError(t, err) + + // Check if its set on data after commit. + assert.Contains(t, str.data, "key1") + assert.Contains(t, str.data, "key2") + assert.Contains(t, str.data, "key3") + assert.Equal(t, data["key1"], str.data["key1"]) + assert.Equal(t, data["key2"], str.data["key2"]) + assert.Equal(t, data["key3"], str.data["key3"]) + assert.Nil(t, sess.cache) + + // Cache and set. + str.data = map[string]interface{}{} + err = sess.CacheAll() + assert.NoError(t, err) + err = sess.SetMulti(data) + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + assert.Equal(t, sess.cache, str.data) + + // Test error. + sess.ResetCache() + str.err = errors.New("store error") + err = sess.SetMulti(data) + assert.ErrorIs(t, str.err, err) +} + +func TestDelete(t *testing.T) { + str := newMockStore() + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + "key3": 3, + } + assert.Contains(t, str.data, "key1") + err = sess.Delete("key1") + assert.NoError(t, err) + assert.NotContains(t, str.data, "key1") + + // Cache and set. + err = sess.CacheAll() + assert.NoError(t, err) + err = sess.Delete("key2") + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + assert.Equal(t, sess.cache, str.data) + + // Test error. + str.err = errors.New("store error") + err = sess.Delete("key2") + assert.ErrorIs(t, str.err, err) +} + +func TestClear(t *testing.T) { + // Test errors. + str := newMockStore() + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + str.err = errors.New("store error") err = sess.Clear() - assert.NoError(err) + assert.ErrorIs(t, str.err, err) - assert.True(isCallbackTriggered) - assert.Equal(mockStore.val, nil) -} - -func TestSessionClearError(t *testing.T) { - assert := assert.New(t) - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockStore.isValid = true - - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - - testError := errors.New("this is test error") - mockStore.err = testError + // Test clear. + str = newMockStore() + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + } + mgr = newMockManager(str) + sess, err = mgr.NewSession(nil, nil) + assert.NoError(t, err) err = sess.Clear() - assert.Error(err, testError.Error()) -} - -func TestSessionClearInvalidSession(t *testing.T) { - mockStore := newMockStore() - mockManager := newMockManager(mockStore) - mockManager.opts.DisableAutoSet = true - mockManager.RegisterGetCookie(func(name string, r interface{}) (*http.Cookie, error) { - return nil, http.ErrNoCookie - }) - - assert := assert.New(t) - sess, err := NewSession(mockManager, nil, nil) - assert.NoError(err) - + assert.NoError(t, err) + assert.Equal(t, 0, len(str.data)) + assert.Nil(t, sess.cache) + + // Test clear. + str = newMockStore() + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + } + mgr = newMockManager(str) + sess, err = mgr.NewSession(nil, nil) + assert.NoError(t, err) + err = sess.CacheAll() + assert.NoError(t, err) + assert.NotNil(t, sess.cache) err = sess.Clear() - assert.Error(err, ErrInvalidSession.Error()) -} - -type Err struct { - code int - msg string -} - -func (e *Err) Error() string { - return e.msg -} - -func (e *Err) Code() int { - return e.code -} + assert.NoError(t, err) + assert.Equal(t, 0, len(str.data)) + assert.Nil(t, sess.cache) +} + +func TestDestroy(t *testing.T) { + // Test errors. + str := newMockStore() + mgr := newMockManager(str) + sess, err := mgr.NewSession(nil, nil) + assert.NoError(t, err) + str.err = errors.New("store error") + err = sess.Destroy() + assert.ErrorIs(t, str.err, err) + + // Test cookie write error. + str.err = nil + ckErr := errors.New("cookie error") + mgr.SetCookieHooks(nil, func(*http.Cookie, interface{}) error { return ckErr }) + + str.data = map[string]interface{}{"foo": "bar"} + err = sess.Destroy() + assert.ErrorIs(t, ckErr, err) + + // Test clear. + str = newMockStore() + mgr = newMockManager(str) + sess, err = mgr.NewSession(nil, nil) + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + } + assert.NoError(t, err) + err = sess.Destroy() + assert.NoError(t, err) + assert.Nil(t, str.data) + assert.Nil(t, sess.cache) + + // Test clear. + str = newMockStore() + mgr = newMockManager(str) + sess, err = mgr.NewSession(nil, nil) + str.data = map[string]interface{}{ + "key1": 1, + "key2": 2, + } + assert.NoError(t, err) + err = sess.CacheAll() + assert.NoError(t, err) + assert.NotNil(t, sess.cache) + err = sess.Clear() + assert.NoError(t, err) + assert.Equal(t, 0, len(str.data)) + assert.Nil(t, sess.cache) -func TestErrorTypes(t *testing.T) { + // Test deleteCookie callback. var ( - // Error codes for store errors. This should match the codes - // defined in the /simplesessions package exactly. - errInvalidSession = &Err{code: 1, msg: "invalid session"} - errFieldNotFound = &Err{code: 2, msg: "field not found"} - errAssertType = &Err{code: 3, msg: "assertion failed"} - errNil = &Err{code: 4, msg: "nil returned"} + receCk *http.Cookie + isCb bool ) - - assert.Equal(t, errAs(errInvalidSession), ErrInvalidSession) - assert.Equal(t, errAs(errFieldNotFound), ErrFieldNotFound) - assert.Equal(t, errAs(errAssertType), ErrAssertType) - assert.Equal(t, errAs(errNil), ErrNil) + mgr.SetCookieHooks(nil, func(ck *http.Cookie, w interface{}) error { + receCk = ck + isCb = true + return nil + }) + err = sess.Destroy() + assert.NoError(t, err) + assert.Equal(t, 0, len(str.data)) + assert.True(t, isCb) + assert.NotNil(t, receCk) + assert.Greater(t, time.Now(), receCk.Expires) } diff --git a/store.go b/store.go index 5f22791..6f66ce5 100644 --- a/store.go +++ b/store.go @@ -3,32 +3,41 @@ package simplesessions // Store represents store interface. This interface can be // implemented to create various backend stores for session. type Store interface { - // Create creates new session in store and returns the cookie value. - Create() (cookieValue string, err error) + // Create creates new session in the store for the given session ID. + Create(id string) (err error) - // Get gets a value for given key from session. - Get(cookieValue, key string) (value interface{}, err error) + // Get a value for the given key from session. + Get(id, key string) (value interface{}, err error) - // GetMulti gets a maps of multiple values for given keys. - GetMulti(cookieValue string, keys ...string) (values map[string]interface{}, err error) + // GetMulti gets a maps of multiple values for given keys from session. + // If some fields are not found then return nil for that field. + GetMulti(id string, keys ...string) (data map[string]interface{}, err error) - // GetAll gets all key and value from session, - GetAll(cookieValue string) (values map[string]interface{}, err error) + // GetAll gets all key and value from session. + GetAll(id string) (data map[string]interface{}, err error) // Set sets an value for a field in session. - // Its up to store to either store it in session right after set or after commit. - Set(cookieValue, key string, value interface{}) error + Set(id, key string, value interface{}) error - // Commit commits all the previously set values to store. - Commit(cookieValue string) error + // Set takes a map of kv pair and set the field in store. + SetMulti(id string, data map[string]interface{}) error - // Delete a field from session. - Delete(cookieValue string, key string) error + // Delete a given list of keys from session. + Delete(id string, key ...string) error - // Clear clears the session key from backend if exists. - Clear(cookieValue string) error + // Clear empties the session but doesn't delete it. + Clear(id string) error + + // Destroy deletes the entire session. + Destroy(id string) error // Helper method for typecasting/asserting. + // Supposed to be used as a chain. + // For example: sess.Int(sess.Get("id", "key")) + // Take `error` and returns that if its not nil. + // Take `interface{}` value and type assert or convert. + // If its nil then return ErrNil. + // If it can't type asserted/converted then return ErrAssertType. Int(interface{}, error) (int, error) Int64(interface{}, error) (int64, error) UInt64(interface{}, error) (uint64, error) diff --git a/store_test.go b/store_test.go index 4eba6fe..f74388d 100644 --- a/store_test.go +++ b/store_test.go @@ -2,58 +2,99 @@ package simplesessions // MockStore mocks the store for testing type MockStore struct { - isValid bool - cookieValue string - err error - val interface{} - isCommited bool + err error + id string + data map[string]interface{} } -func (s *MockStore) reset() { - s.isValid = false - s.cookieValue = "" - s.err = nil - s.val = nil - s.isCommited = false +func (s *MockStore) Create(id string) error { + s.data = make(map[string]interface{}) + return s.err } -func (s *MockStore) Create() (cv string, err error) { - return s.val.(string), s.err -} +func (s *MockStore) Get(id, key string) (interface{}, error) { + if s.id == "" || s.data == nil { + return nil, ErrInvalidSession + } -func (s *MockStore) Get(cv, key string) (value interface{}, err error) { - return s.val, s.err + d, ok := s.data[key] + if !ok { + return nil, nil + } + return d, s.err } -func (s *MockStore) GetMulti(cv string, keys ...string) (values map[string]interface{}, err error) { - vals := make(map[string]interface{}) - vals["val"] = s.val - return vals, s.err +func (s *MockStore) GetMulti(id string, keys ...string) (values map[string]interface{}, err error) { + if s.id == "" || s.data == nil { + return nil, ErrInvalidSession + } + + out := make(map[string]interface{}) + for _, key := range keys { + v, ok := s.data[key] + if !ok { + v = err + } + out[key] = v + } + + return out, s.err } -func (s *MockStore) GetAll(cv string) (values map[string]interface{}, err error) { - vals := make(map[string]interface{}) - vals["val"] = s.val - return vals, s.err +func (s *MockStore) GetAll(id string) (values map[string]interface{}, err error) { + if s.id == "" || s.data == nil { + return nil, ErrInvalidSession + } + + return s.data, s.err } func (s *MockStore) Set(cv, key string, value interface{}) error { - s.val = value + if s.id == "" || s.data == nil { + return ErrInvalidSession + } + + s.data[key] = value return s.err } -func (s *MockStore) Commit(cv string) error { - s.isCommited = true +func (s *MockStore) SetMulti(id string, data map[string]interface{}) error { + if s.id == "" || s.data == nil { + return ErrInvalidSession + } + + for k, v := range data { + s.data[k] = v + } return s.err } -func (s *MockStore) Delete(cv string, key string) error { - s.val = nil +func (s *MockStore) Delete(id string, key ...string) error { + if s.id == "" || s.data == nil { + return ErrInvalidSession + } + + for _, k := range key { + delete(s.data, k) + } + return s.err +} + +func (s *MockStore) Clear(id string) error { + if s.id == "" || s.data == nil { + return ErrInvalidSession + } + + s.data = make(map[string]interface{}) return s.err } -func (s *MockStore) Clear(cv string) error { - s.val = nil +func (s *MockStore) Destroy(id string) error { + if s.id == "" || s.data == nil { + return ErrInvalidSession + } + + s.data = nil return s.err } diff --git a/stores/goredis/go.mod b/stores/goredis/go.mod deleted file mode 100644 index b268fed..0000000 --- a/stores/goredis/go.mod +++ /dev/null @@ -1,20 +0,0 @@ -module github.com/vividvilla/simplesessions/stores/goredis/v9 - -go 1.18 - -require ( - github.com/alicebob/miniredis/v2 v2.32.1 - github.com/redis/go-redis/v9 v9.5.1 - github.com/stretchr/testify v1.9.0 - github.com/vividvilla/simplesessions/conv v1.0.0 -) - -require ( - github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/yuin/gopher-lua v1.1.1 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) diff --git a/stores/goredis/store.go b/stores/goredis/store.go deleted file mode 100644 index 2e4de30..0000000 --- a/stores/goredis/store.go +++ /dev/null @@ -1,298 +0,0 @@ -package goredis - -import ( - "context" - "crypto/rand" - "sync" - "time" - "unicode" - - "github.com/redis/go-redis/v9" - "github.com/vividvilla/simplesessions/conv" -) - -var ( - // Error codes for store errors. This should match the codes - // defined in the /simplesessions package exactly. - ErrInvalidSession = &Err{code: 1, msg: "invalid session"} - ErrFieldNotFound = &Err{code: 2, msg: "field not found"} - ErrAssertType = &Err{code: 3, msg: "assertion failed"} - ErrNil = &Err{code: 4, msg: "nil returned"} -) - -type Err struct { - code int - msg string -} - -func (e *Err) Error() string { - return e.msg -} - -func (e *Err) Code() int { - return e.code -} - -// Store represents redis session store for simple sessions. -// Each session is stored as redis hashmap. -type Store struct { - // Maximum lifetime sessions has to be persisted. - ttl time.Duration - - // Prefix for session id. - prefix string - - // Temp map to store values before commit. - tempSetMap map[string]map[string]interface{} - mu sync.RWMutex - - // Redis client - client redis.UniversalClient - clientCtx context.Context -} - -const ( - // Default prefix used to store session redis - defaultPrefix = "session:" - sessionIDLen = 32 -) - -// New creates a new Redis store instance. -func New(ctx context.Context, client redis.UniversalClient) *Store { - return &Store{ - clientCtx: ctx, - client: client, - prefix: defaultPrefix, - tempSetMap: make(map[string]map[string]interface{}), - } -} - -// SetPrefix sets session id prefix in backend -func (s *Store) SetPrefix(val string) { - s.prefix = val -} - -// SetTTL sets TTL for session in redis. -func (s *Store) SetTTL(d time.Duration) { - s.ttl = d -} - -// Create returns a new session id but doesn't stores it in redis since empty hashmap can't be created. -func (s *Store) Create() (string, error) { - id, err := generateID(sessionIDLen) - if err != nil { - return "", err - } - - return id, err -} - -// Get gets a field in hashmap. If field is nill then ErrFieldNotFound is raised -func (s *Store) Get(id, key string) (interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - v, err := s.client.HGet(s.clientCtx, s.prefix+id, key).Result() - if err == redis.Nil { - return nil, ErrFieldNotFound - } - - return v, err -} - -// GetMulti gets a map for values for multiple keys. If key is not found then its set as nil. -func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - v, err := s.client.HMGet(s.clientCtx, s.prefix+id, keys...).Result() - // If field is not found then return map with fields as nil - if len(v) == 0 || err == redis.Nil { - v = make([]interface{}, len(keys)) - } - - // Form a map with returned results - res := make(map[string]interface{}) - for i, k := range keys { - res[k] = v[i] - } - - return res, err -} - -// GetAll gets all fields from hashmap. -func (s *Store) GetAll(id string) (map[string]interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - res, err := s.client.HGetAll(s.clientCtx, s.prefix+id).Result() - if res == nil || err == redis.Nil { - return map[string]interface{}{}, nil - } else if err != nil { - return nil, err - } - - // Convert results to type `map[string]interface{}` - out := make(map[string]interface{}, len(res)) - for k, v := range res { - out[k] = v - } - - return out, nil -} - -// Set sets a value to given session but stored only on commit -func (s *Store) Set(id, key string, val interface{}) error { - if !validateID(id) { - return ErrInvalidSession - } - - s.mu.Lock() - defer s.mu.Unlock() - - // Create session map if doesn't exist - if _, ok := s.tempSetMap[id]; !ok { - s.tempSetMap[id] = make(map[string]interface{}) - } - - // set value to map - s.tempSetMap[id][key] = val - - return nil -} - -// Commit sets all set values. -func (s *Store) Commit(id string) error { - if !validateID(id) { - return ErrInvalidSession - } - - s.mu.RLock() - vals, ok := s.tempSetMap[id] - if !ok { - // Nothing to commit - s.mu.RUnlock() - return nil - } - - // Make slice of arguments to be passed in HGETALL command - args := make([]interface{}, len(vals)*2, len(vals)*2) - c := 0 - for k, v := range s.tempSetMap[id] { - args[c] = k - args[c+1] = v - c += 2 - } - s.mu.RUnlock() - - // Clear temp map for given session id - s.mu.Lock() - delete(s.tempSetMap, id) - s.mu.Unlock() - - pipe := s.client.TxPipeline() - pipe.HMSet(s.clientCtx, s.prefix+id, args...) - // Set expiry of key only if 'ttl' is set, this is to - // ensure that the key remains valid indefinitely like - // how redis handles it by default - if s.ttl > 0 { - pipe.Expire(s.clientCtx, s.prefix+id, s.ttl) - } - - _, err := pipe.Exec(s.clientCtx) - return err -} - -// Delete deletes a key from redis session hashmap. -func (s *Store) Delete(id string, key string) error { - if !validateID(id) { - return ErrInvalidSession - } - - // Clear temp map for given session id - s.mu.Lock() - delete(s.tempSetMap, id) - s.mu.Unlock() - - err := s.client.HDel(s.clientCtx, s.prefix+id, key).Err() - if err == redis.Nil { - return ErrFieldNotFound - } - return err -} - -// Clear clears session in redis. -func (s *Store) Clear(id string) error { - if !validateID(id) { - return ErrInvalidSession - } - - return s.client.Del(s.clientCtx, s.prefix+id).Err() -} - -// Int returns redis reply as integer. -func (s *Store) Int(r interface{}, err error) (int, error) { - return conv.Int(r, err) -} - -// Int64 returns redis reply as Int64. -func (s *Store) Int64(r interface{}, err error) (int64, error) { - return conv.Int64(r, err) -} - -// UInt64 returns redis reply as UInt64. -func (s *Store) UInt64(r interface{}, err error) (uint64, error) { - return conv.UInt64(r, err) -} - -// Float64 returns redis reply as Float64. -func (s *Store) Float64(r interface{}, err error) (float64, error) { - return conv.Float64(r, err) -} - -// String returns redis reply as String. -func (s *Store) String(r interface{}, err error) (string, error) { - return conv.String(r, err) -} - -// Bytes returns redis reply as Bytes. -func (s *Store) Bytes(r interface{}, err error) ([]byte, error) { - return conv.Bytes(r, err) -} - -// Bool returns redis reply as Bool. -func (s *Store) Bool(r interface{}, err error) (bool, error) { - return conv.Bool(r, err) -} - -func validateID(id string) bool { - if len(id) != sessionIDLen { - return false - } - - for _, r := range id { - if !unicode.IsDigit(r) && !unicode.IsLetter(r) { - return false - } - } - - return true -} - -// generateID generates a random alpha-num session ID. -func generateID(n int) (string, error) { - const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - bytes := make([]byte, n) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - - for k, v := range bytes { - bytes[k] = dict[v%byte(len(dict))] - } - - return string(bytes), nil -} diff --git a/stores/goredis/store_test.go b/stores/goredis/store_test.go deleted file mode 100644 index 0f51d94..0000000 --- a/stores/goredis/store_test.go +++ /dev/null @@ -1,491 +0,0 @@ -package goredis - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/alicebob/miniredis/v2" - "github.com/redis/go-redis/v9" - "github.com/stretchr/testify/assert" -) - -var ( - mockRedis *miniredis.Miniredis -) - -func init() { - var err error - mockRedis, err = miniredis.Run() - if err != nil { - panic(err) - } -} - -func getRedisClient() redis.UniversalClient { - return redis.NewClient(&redis.Options{ - Addr: mockRedis.Addr(), - }) -} - -func TestNew(t *testing.T) { - assert := assert.New(t) - client := getRedisClient() - ctx := context.Background() - str := New(ctx, client) - assert.Equal(str.prefix, defaultPrefix) - assert.Equal(str.client, client) - assert.Equal(str.clientCtx, ctx) - assert.NotNil(str.tempSetMap) -} - -func TestSetPrefix(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - str.SetPrefix("test") - assert.Equal(str.prefix, "test") -} - -func TestSetTTL(t *testing.T) { - assert := assert.New(t) - testDur := time.Second * 10 - str := New(context.TODO(), getRedisClient()) - str.SetTTL(testDur) - assert.Equal(str.ttl, testDur) -} - -func TestCreate(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - id, err := str.Create() - assert.Nil(err) - assert.Equal(len(id), sessionIDLen) -} - -func TestGetInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - val, err := str.Get("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession) -} - -func TestGet(t *testing.T) { - assert := assert.New(t) - key := "4dIHy6S2uBuKaNnTUszB218L898ikGY1" - field := "somekey" - value := 100 - client := getRedisClient() - - // Set a key - err := client.HSet(context.TODO(), defaultPrefix+key, field, value).Err() - assert.NoError(err) - - str := New(context.TODO(), client) - - val, err := str.Int(str.Get(key, field)) - assert.NoError(err) - assert.Equal(val, value) -} - -func TestGetFieldNotFoundError(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - key := "10IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err := str.Get(key, "invalidkey") - assert.Nil(val) - assert.Error(err, ErrFieldNotFound.Error()) -} - -func TestGetMultiInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - val, err := str.GetMulti("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) -} - -func TestGetMultiFieldEmptySession(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" - field := "somefield" - _, err := str.GetMulti(key, field) - assert.Nil(err) -} - -func TestGetMulti(t *testing.T) { - assert := assert.New(t) - key := "5dIHy6S2uBuKaNnTUszB218L898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - field3 := "thishouldntbethere" - value3 := 100.10 - client := getRedisClient() - - // Set a key - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2, field3, value3).Err() - assert.NoError(err) - - str := New(context.TODO(), client) - - vals, err := str.GetMulti(key, field1, field2) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - assert.NotContains(vals, field3) - - val1, err := str.Int(vals[field1], nil) - assert.NoError(err) - assert.Equal(val1, value1) - - val2, err := str.String(vals[field2], nil) - assert.NoError(err) - assert.Equal(val2, value2) -} - -func TestGetAllInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - val, err := str.GetAll("invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) -} - -func TestGetAll(t *testing.T) { - assert := assert.New(t) - key := "6dIHy6S2uBuKaNnTUszB218L898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - field3 := "thishouldntbethere" - value3 := 100.10 - client := getRedisClient() - - // Set a key - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2, field3, value3).Err() - assert.NoError(err) - - str := New(context.TODO(), client) - - vals, err := str.GetAll(key) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - assert.Contains(vals, field3) - - val1, err := str.Int(vals[field1], nil) - assert.NoError(err) - assert.Equal(val1, value1) - - val2, err := str.String(vals[field2], nil) - assert.NoError(err) - assert.Equal(val2, value2) - - val3, err := str.Float64(vals[field3], nil) - assert.NoError(err) - assert.Equal(val3, value3) -} - -func TestSetInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - err := str.Set("invalidid", "key", "value") - assert.Error(err, ErrInvalidSession.Error()) -} - -func TestSet(t *testing.T) { - // Test should only set in internal map and not in redis - assert := assert.New(t) - client := getRedisClient() - str := New(context.TODO(), client) - - // this key is unique across all tests - key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" - field := "somekey" - value := 100 - - assert.NotNil(str.tempSetMap) - assert.NotContains(str.tempSetMap, key) - - err := str.Set(key, field, value) - assert.NoError(err) - assert.Contains(str.tempSetMap, key) - assert.Contains(str.tempSetMap[key], field) - assert.Equal(str.tempSetMap[key][field], value) - - // Check ifs not commited to redis - val, err := client.Exists(context.TODO(), defaultPrefix+key).Result() - assert.NoError(err) - assert.Equal(val, int64(0)) -} - -func TestCommitInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - err := str.Commit("invalidkey") - assert.Error(err, ErrInvalidSession.Error()) -} - -func TestEmptyCommit(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - err := str.Commit("15IHy6S2uBuKaNnTUszB2180898ikGY1") - assert.NoError(err) -} - -func TestCommit(t *testing.T) { - // Test should commit in redis with expiry on key - assert := assert.New(t) - client := getRedisClient() - str := New(context.TODO(), client) - - str.SetTTL(10 * time.Second) - - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - err := str.Set(key, field1, value1) - assert.NoError(err) - - err = str.Set(key, field2, value2) - assert.NoError(err) - - err = str.Commit(key) - assert.NoError(err) - - vals, err := client.HGetAll(context.TODO(), defaultPrefix+key).Result() - assert.Equal(2, len(vals)) - - ttl, err := client.TTL(context.TODO(), defaultPrefix+key).Result() - assert.NoError(err) - assert.Equal(true, ttl.Seconds() > 0 && ttl.Seconds() <= 10) -} - -func TestDeleteInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - err := str.Delete("invalidkey", "somefield") - assert.Error(err, ErrInvalidSession.Error()) -} - -func TestDelete(t *testing.T) { - // Test should only set in internal map and not in redis - assert := assert.New(t) - client := getRedisClient() - str := New(context.TODO(), client) - - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2).Err() - assert.NoError(err) - - err = str.Delete(key, field1) - assert.NoError(err) - - val, err := client.HExists(context.TODO(), defaultPrefix+key, field1).Result() - assert.False(val) - - val, err = client.HExists(context.TODO(), defaultPrefix+key, field2).Result() - assert.True(val) -} - -func TestClearInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - err := str.Clear("invalidkey") - assert.Error(err, ErrInvalidSession.Error()) -} - -func TestClear(t *testing.T) { - // Test should only set in internal map and not in redis - assert := assert.New(t) - client := getRedisClient() - str := New(context.TODO(), client) - - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2).Err() - assert.NoError(err) - - // Check if its set - val, err := client.Exists(context.TODO(), defaultPrefix+key).Result() - assert.NoError(err) - assert.NotEqual(val, int64(0)) - - err = str.Clear(key) - assert.NoError(err) - - val, err = client.Exists(context.TODO(), defaultPrefix+key).Result() - assert.NoError(err) - assert.Equal(val, int64(0)) -} - -func TestInt(t *testing.T) { - assert := assert.New(t) - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - value := 100 - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) - - val, err := str.Int(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) - - testError := errors.New("test error") - val, err = str.Int(value, testError) - assert.Error(err, testError.Error()) -} - -func TestInt64(t *testing.T) { - assert := assert.New(t) - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - var value int64 = 100 - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) - - val, err := str.Int64(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) - - testError := errors.New("test error") - val, err = str.Int64(value, testError) - assert.Error(err, testError.Error()) -} - -func TestUInt64(t *testing.T) { - assert := assert.New(t) - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - var value uint64 = 100 - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) - - val, err := str.UInt64(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) - - testError := errors.New("test error") - val, err = str.UInt64(value, testError) - assert.Error(err, testError.Error()) -} - -func TestFloat64(t *testing.T) { - assert := assert.New(t) - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - var value float64 = 100 - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) - - val, err := str.Float64(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) - - testError := errors.New("test error") - val, err = str.Float64(value, testError) - assert.Error(err, testError.Error()) -} - -func TestString(t *testing.T) { - assert := assert.New(t) - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - value := "abc123" - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) - - val, err := str.String(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) - - testError := errors.New("test error") - val, err = str.String(value, testError) - assert.Error(err, testError.Error()) -} - -func TestBytes(t *testing.T) { - assert := assert.New(t) - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - var value []byte = []byte("abc123") - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) - - val, err := str.Bytes(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) - - testError := errors.New("test error") - val, err = str.Bytes(value, testError) - assert.Error(err, testError.Error()) -} - -func TestBool(t *testing.T) { - assert := assert.New(t) - client := getRedisClient() - str := New(context.TODO(), client) - - field := "somekey" - value := true - - err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) - - val, err := str.Bool(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) - - testError := errors.New("test error") - val, err = str.Bool(value, testError) - assert.Error(err, testError.Error()) -} diff --git a/stores/memory/go.mod b/stores/memory/go.mod index b39b0e1..1936154 100644 --- a/stores/memory/go.mod +++ b/stores/memory/go.mod @@ -1,4 +1,4 @@ -module github.com/vividvilla/simplesessions/stores/memory/v2 +module github.com/vividvilla/simplesessions/stores/memory/v3 go 1.18 diff --git a/stores/memory/store.go b/stores/memory/store.go index 69e78d6..fe5625c 100644 --- a/stores/memory/store.go +++ b/stores/memory/store.go @@ -1,22 +1,15 @@ package memory import ( - "crypto/rand" "sync" - "unicode" -) - -const ( - sessionIDLen = 32 ) var ( // Error codes for store errors. This should match the codes // defined in the /simplesessions package exactly. ErrInvalidSession = &Err{code: 1, msg: "invalid session"} - ErrFieldNotFound = &Err{code: 2, msg: "field not found"} + ErrNil = &Err{code: 2, msg: "nil returned"} ErrAssertType = &Err{code: 3, msg: "assertion failed"} - ErrNil = &Err{code: 4, msg: "nil returned"} ) type Err struct { @@ -50,34 +43,34 @@ func New() *Store { // Create creates a new session id and returns it. This doesn't create the session in // sessions map since memory can be saved by not storing empty sessions and system // can not be stressed by just creating new sessions -func (s *Store) Create() (string, error) { - id, err := generateID(sessionIDLen) - if err != nil { - return "", err +func (s *Store) Create(id string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Check if session already exists. + _, ok := s.sessions[id] + if ok { + return nil } - return id, err + s.sessions[id] = make(map[string]interface{}) + return nil } // Get gets a field in session func (s *Store) Get(id, key string) (interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - var val interface{} s.mu.RLock() - // Check if session exists before accessing key from it - v, ok := s.sessions[id] - if ok && v != nil { - val, ok = s.sessions[id][key] + defer s.mu.RUnlock() + + // Check if session exists before accessing key from it. + sess, ok := s.sessions[id] + if !ok { + return nil, ErrInvalidSession } - s.mu.RUnlock() - // If session doesn't exist or field doesn't exist then send field not found error - // since we don't add session to sessions map on session create - if !ok || v == nil { - return nil, ErrFieldNotFound + val, ok := sess[key] + if !ok { + return nil, nil } return val, nil @@ -85,102 +78,117 @@ func (s *Store) Get(id, key string) (interface{}, error) { // GetMulti gets a map for values for multiple keys. If key is not present in session then nil is returned. func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - s.mu.RLock() defer s.mu.RUnlock() - sVals, ok := s.sessions[id] - // If session not set then send a map with value for all keys is nil - if sVals == nil || !ok { - sVals = make(map[string]interface{}) + sess, ok := s.sessions[id] + if !ok { + return nil, ErrInvalidSession } - res := make(map[string]interface{}) + out := make(map[string]interface{}) for _, k := range keys { - v, ok := sVals[k] + v, ok := sess[k] if !ok { - res[k] = nil + out[k] = nil } else { - res[k] = v + out[k] = v } } - return res, nil + return out, nil } // GetAll gets all fields in session func (s *Store) GetAll(id string) (map[string]interface{}, error) { - if !validateID(id) { + s.mu.RLock() + defer s.mu.RUnlock() + + sess, ok := s.sessions[id] + if !ok { return nil, ErrInvalidSession } - s.mu.RLock() - defer s.mu.RUnlock() - vals := s.sessions[id] + // Copy the map. + out := make(map[string]interface{}) + for k, v := range sess { + out[k] = v + } - return vals, nil + return out, nil } -// Set sets a value to given session but stored only on commit +// Set sets a value to given session. func (s *Store) Set(id, key string, val interface{}) error { - if !validateID(id) { + s.mu.Lock() + defer s.mu.Unlock() + + _, ok := s.sessions[id] + if !ok { return ErrInvalidSession } + s.sessions[id][key] = val + return nil +} +// SetMulti sets multiple key value pair to given session. +func (s *Store) SetMulti(id string, data map[string]interface{}) error { s.mu.Lock() defer s.mu.Unlock() - // If session is not set previously then create empty map + _, ok := s.sessions[id] if !ok { - s.sessions[id] = make(map[string]interface{}) + return ErrInvalidSession } - s.sessions[id][key] = val - - return nil -} + for k, v := range data { + s.sessions[id][k] = v + } -// Commit does nothing here since Set sets the value. -func (s *Store) Commit(id string) error { return nil } // Delete deletes a key from session. -func (s *Store) Delete(id string, key string) error { - if !validateID(id) { - return ErrInvalidSession - } - +func (s *Store) Delete(id string, keys ...string) error { s.mu.Lock() defer s.mu.Unlock() _, ok := s.sessions[id] - if ok && s.sessions[id] != nil { - _, ok = s.sessions[id][key] - if ok { - delete(s.sessions[id], key) - } + if !ok { + return ErrInvalidSession + } + + for _, k := range keys { + delete(s.sessions[id], k) } return nil } -// Clear clears session in redis. +// Clear empties the session. func (s *Store) Clear(id string) error { - if !validateID(id) { + s.mu.Lock() + defer s.mu.Unlock() + + _, ok := s.sessions[id] + if !ok { return ErrInvalidSession } + s.sessions[id] = make(map[string]interface{}) + return nil +} + +// Destroy deletes the entire session. +func (s *Store) Destroy(id string) error { s.mu.Lock() defer s.mu.Unlock() _, ok := s.sessions[id] - if ok { - delete(s.sessions, id) + if !ok { + return ErrInvalidSession } + delete(s.sessions, id) return nil } @@ -282,32 +290,3 @@ func (s *Store) Bool(r interface{}, err error) (bool, error) { return v, err } - -func validateID(id string) bool { - if len(id) != sessionIDLen { - return false - } - - for _, r := range id { - if !unicode.IsDigit(r) && !unicode.IsLetter(r) { - return false - } - } - - return true -} - -// generateID generates a random alpha-num session ID. -func generateID(n int) (string, error) { - const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - bytes := make([]byte, n) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - - for k, v := range bytes { - bytes[k] = dict[v%byte(len(dict))] - } - - return string(bytes), nil -} diff --git a/stores/memory/store_test.go b/stores/memory/store_test.go index 1149189..8e0b01a 100644 --- a/stores/memory/store_test.go +++ b/stores/memory/store_test.go @@ -13,382 +13,333 @@ func TestNew(t *testing.T) { assert.NotNil(str.sessions) } -func TestIsValidSessionID(t *testing.T) { - assert := assert.New(t) - - // Not valid since length doesn't match - testString := "abc123" - assert.NotEqual(len(testString), sessionIDLen) - assert.False(validateID(testString)) - - // Not valid since length is same but not alpha numeric - invalidTestString := "0dIHy6S2uBuKaNnTUszB218L898ikGY$" - assert.Equal(len(invalidTestString), sessionIDLen) - assert.False(validateID(invalidTestString)) - - // Valid - validTestString := "1dIHy6S2uBuKaNnTUszB218L898ikGY1" - assert.Equal(len(validTestString), sessionIDLen) - assert.True(validateID(validTestString)) -} - -func TestIsValid(t *testing.T) { - assert := assert.New(t) - - // Not valid since length doesn't match - testString := "abc123" - assert.NotEqual(len(testString), sessionIDLen) - assert.False(validateID(testString)) - - // Not valid since length is same but not alpha numeric - invalidTestString := "2dIHy6S2uBuKaNnTUszB218L898ikGY$" - assert.Equal(len(invalidTestString), sessionIDLen) - assert.False(validateID(invalidTestString)) - - // Valid - validTestString := "3dIHy6S2uBuKaNnTUszB218L898ikGY1" - assert.Equal(len(validTestString), sessionIDLen) - assert.True(validateID(validTestString)) -} - func TestCreate(t *testing.T) { - assert := assert.New(t) - str := New() - - id, err := str.Create() - assert.Nil(err) - assert.Equal(len(id), sessionIDLen) - assert.True(validateID(id)) -} - -func TestGetInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New() - - val, err := str.Get("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) + var ( + id = "testid" + str = New() + ) + assert.NotContains(t, str.sessions, id) + err := str.Create(id) + assert.NoError(t, err) + assert.NotNil(t, str.sessions, id) + + // Check if existing session is not overwritten on Create. + val := map[string]interface{}{"foo": "bar"} + str.sessions["existing_id"] = val + err = str.Create("existing_id") + assert.NoError(t, err) + assert.Equal(t, val, str.sessions["existing_id"]) } func TestGet(t *testing.T) { - assert := assert.New(t) - key := "4dIHy6S2uBuKaNnTUszB218L898ikGY1" - field := "somekey" - value := 100 - - // Set a key - str := New() - - str.sessions[key] = make(map[string]interface{}) - str.sessions[key][field] = value - - val, err := str.Get(key, field) - assert.NoError(err) - assert.Equal(val, value) -} - -func TestGetFieldNotFoundError(t *testing.T) { - assert := assert.New(t) - str := New() - - key := "10IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err := str.Get(key, "invalidkey") - assert.Nil(val) - assert.Error(err, ErrFieldNotFound.Error()) -} - -func TestGetMultiInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New() - - val, err := str.GetMulti("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) -} - -func TestGetMultiFieldEmptySession(t *testing.T) { - assert := assert.New(t) - str := New() - - key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" - _, err := str.GetMulti(key) - assert.Nil(err) + var ( + id = "testid" + field = "somekey" + value = 100 + str = New() + ) + + _, err := str.Get("invalidkey", "invalidkey") + assert.ErrorIs(t, ErrInvalidSession, err) + + str.sessions[id] = make(map[string]interface{}) + str.sessions[id][field] = value + + val, err := str.Get(id, field) + assert.NoError(t, err) + assert.Equal(t, val, value) + + val, err = str.Get(id, "invalid") + assert.NoError(t, err) + assert.Nil(t, val) } func TestGetMulti(t *testing.T) { - assert := assert.New(t) - key := "5dIHy6S2uBuKaNnTUszB218L898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - field3 := "thishouldntbethere" - value3 := 100.10 - str := New() + _, err := str.GetMulti("invalidkey", "invalidkey1", "invalidkey2") + assert.ErrorIs(t, ErrInvalidSession, err) + + var ( + id = "testid" + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + field3 = "thishouldntbethere" + ) // Set a key - str.sessions[key] = make(map[string]interface{}) - str.sessions[key][field1] = value1 - str.sessions[key][field2] = value2 - str.sessions[key][field3] = value3 + str.sessions[id] = make(map[string]interface{}) + str.sessions[id][field1] = value1 + str.sessions[id][field2] = value2 - vals, err := str.GetMulti(key, field1, field2) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - assert.NotContains(vals, field3) + vals, err := str.GetMulti(id, field1, field2, field3) + assert.NoError(t, err) + assert.Contains(t, vals, field1) + assert.Equal(t, value1, vals[field1]) - assert.NoError(err) - assert.Equal(vals[field1], value1) + assert.Contains(t, vals, field2) + assert.Equal(t, value2, vals[field2]) - assert.NoError(err) - assert.Equal(vals[field2], value2) -} - -func TestGetAllInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New() - - val, err := str.GetAll("invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) + assert.Contains(t, vals, field3) + assert.Nil(t, vals[field3]) } func TestGetAll(t *testing.T) { - assert := assert.New(t) - key := "6dIHy6S2uBuKaNnTUszB218L898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - field3 := "thishouldntbethere" - value3 := 100.10 - str := New() + _, err := str.GetAll("invalidkey") + assert.ErrorIs(t, ErrInvalidSession, err) + + var ( + key = "testid" + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + field3 = "thishouldntbethere" + ) // Set a key str.sessions[key] = make(map[string]interface{}) str.sessions[key][field1] = value1 str.sessions[key][field2] = value2 - str.sessions[key][field3] = value3 vals, err := str.GetAll(key) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - assert.Contains(vals, field3) - - assert.NoError(err) - assert.Equal(vals[field1], value1) - - assert.NoError(err) - assert.Equal(vals[field2], value2) - - assert.NoError(err) - assert.Equal(vals[field3], value3) -} - -func TestSetInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New() + assert.NoError(t, err) + assert.Contains(t, vals, field1) + assert.Contains(t, vals, field2) + assert.NotContains(t, vals, field3) - err := str.Set("invalidid", "key", "value") - assert.Error(err, ErrInvalidSession.Error()) + assert.Equal(t, value1, vals[field1]) + assert.Equal(t, value2, vals[field2]) } func TestSet(t *testing.T) { - // Test should only set in internal map and not in redis - assert := assert.New(t) str := New() - - // this key is unique across all tests - key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" - field := "somekey" - value := 100 - - assert.NotContains(str.sessions, key) - - err := str.Set(key, field, value) - assert.NoError(err) - assert.Contains(str.sessions, key) - assert.Contains(str.sessions[key], field) - assert.Equal(str.sessions[key][field], value) + err := str.Set("invalidkey", "key", "val") + assert.ErrorIs(t, ErrInvalidSession, err) + + // this id is unique across all tests + var ( + id = "testid" + field = "somekey" + value = 100 + ) + assert.NotContains(t, str.sessions, id) + + str.sessions[id] = map[string]interface{}{ + field: value, + } + err = str.Set(id, field, value) + assert.NoError(t, err) + assert.Contains(t, str.sessions, id) + assert.Contains(t, str.sessions[id], field) + assert.Equal(t, value, str.sessions[id][field]) } -func TestCommit(t *testing.T) { - assert := assert.New(t) +func TestSetMulti(t *testing.T) { str := New() - - err := str.Commit("invalidkey") - assert.Nil(err) -} - -func TestDeleteInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New() - - err := str.Delete("invalidkey", "somekey") - assert.Error(err, ErrInvalidSession.Error()) + err := str.SetMulti("invalidkey", map[string]interface{}{"foo": "bar"}) + assert.ErrorIs(t, ErrInvalidSession, err) + + // this id is unique across all tests + var ( + id = "testid" + field1 = "somekey1" + value1 = 100 + field2 = "somekey2" + value2 = 100 + ) + str.sessions[id] = map[string]interface{}{} + err = str.SetMulti(id, map[string]interface{}{ + field1: value1, + field2: value2, + }) + assert.NoError(t, err) + assert.Contains(t, str.sessions, id) + assert.Contains(t, str.sessions[id], field1) + assert.Contains(t, str.sessions[id], field2) + assert.Equal(t, value1, str.sessions[id][field1]) + assert.Equal(t, value2, str.sessions[id][field2]) } func TestDelete(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) str := New() + err := str.Delete("invalidkey", "key") + assert.ErrorIs(t, ErrInvalidSession, err) // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somefield1" - field2 := "somefield2" + var ( + key = "8dIHy6S2uBuKaNnTUszB2180898ikGY1" + field1 = "somefield1" + field2 = "somefield2" + ) str.sessions[key] = make(map[string]interface{}) str.sessions[key][field1] = 10 str.sessions[key][field2] = 10 - err := str.Delete(key, field1) - assert.NoError(err) - assert.Contains(str.sessions[key], field2) - assert.NotContains(str.sessions[key], field1) + err = str.Delete(key, field1) + assert.NoError(t, err) + assert.Contains(t, str.sessions[key], field2) + assert.NotContains(t, str.sessions[key], field1) } -func TestClearInvalidSessionError(t *testing.T) { - assert := assert.New(t) +func TestClear(t *testing.T) { + // Test should only set in internal map and not in redis str := New() - err := str.Clear("invalidkey") - assert.Error(err, ErrInvalidSession.Error()) + assert.ErrorIs(t, ErrInvalidSession, err) + + // this id is unique across all tests + id := "test_id" + str.sessions[id] = make(map[string]interface{}) + + err = str.Clear(id) + assert.NoError(t, err) + assert.Contains(t, str.sessions, id) + assert.Equal(t, len(str.sessions[id]), 0) } -func TestClear(t *testing.T) { +func TestDestroy(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) str := New() + err := str.Destroy("invalidkey") + assert.ErrorIs(t, ErrInvalidSession, err) - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - str.sessions[key] = make(map[string]interface{}) + // this id is unique across all tests + id := "test_id" + str.sessions[id] = make(map[string]interface{}) - err := str.Clear(key) - assert.NoError(err) - assert.NotContains(str.sessions, key) + err = str.Destroy(id) + assert.NoError(t, err) + assert.NotContains(t, str.sessions, id) } func TestInt(t *testing.T) { - assert := assert.New(t) str := New() var want int = 10 v, err := str.Int(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") v, err = str.Int(want, testError) - assert.Equal(v, 0) - assert.Error(testError) + assert.Equal(t, v, 0) + assert.ErrorIs(t, testError, err) _, err = str.Int("string", nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, ErrAssertType, err) } func TestInt64(t *testing.T) { - assert := assert.New(t) str := New() var want int64 = 10 v, err := str.Int64(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") v, err = str.Int64(want, testError) - assert.Error(testError) + assert.Equal(t, v, int64(0)) + assert.ErrorIs(t, testError, err) _, err = str.Int64("string", nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, ErrAssertType, err) } func TestUInt64(t *testing.T) { - assert := assert.New(t) str := New() var want uint64 = 10 v, err := str.UInt64(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") v, err = str.UInt64(want, testError) - assert.Error(testError) + assert.Equal(t, v, uint64(0)) + assert.ErrorIs(t, testError, err) _, err = str.UInt64("string", nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, ErrAssertType, err) } func TestFloat64(t *testing.T) { - assert := assert.New(t) str := New() var want float64 = 10 v, err := str.Float64(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") v, err = str.Float64(want, testError) - assert.Error(testError) + assert.Equal(t, v, float64(0)) + assert.ErrorIs(t, testError, err) _, err = str.Float64("string", nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, ErrAssertType, err) } func TestString(t *testing.T) { - assert := assert.New(t) str := New() var want = "string" v, err := str.String(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") v, err = str.String(want, testError) - assert.Error(testError) + assert.Equal(t, v, "") + assert.ErrorIs(t, testError, err) _, err = str.String(123, nil) - assert.Error(ErrAssertType) + assert.Error(t, ErrAssertType, err) } func TestBytes(t *testing.T) { - assert := assert.New(t) str := New() var want = []byte("a") v, err := str.Bytes(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") v, err = str.Bytes(want, testError) - assert.Error(testError) + assert.Equal(t, v, []byte(nil)) + assert.ErrorIs(t, testError, err) _, err = str.Bytes("string", nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, ErrAssertType, err) } func TestBool(t *testing.T) { - assert := assert.New(t) str := New() var want = true v, err := str.Bool(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") v, err = str.Bool(want, testError) - assert.Error(testError) + assert.Equal(t, v, false) + assert.ErrorIs(t, testError, err) _, err = str.Bool("string", nil) - assert.Error(ErrAssertType) + assert.Error(t, ErrAssertType, err) +} + +func TestError(t *testing.T) { + err := Err{ + code: 1, + msg: "test", + } + assert.Equal(t, 1, err.Code()) + assert.Equal(t, "test", err.Error()) } diff --git a/stores/postgres/go.mod b/stores/postgres/go.mod index 002772c..bb46336 100644 --- a/stores/postgres/go.mod +++ b/stores/postgres/go.mod @@ -1,5 +1,14 @@ -module github.com/vividvilla/simplesessions/stores/postgres +module github.com/vividvilla/simplesessions/stores/postgres/v3 -go 1.14 +go 1.18 -require github.com/lib/pq v1.10.9 +require ( + github.com/lib/pq v1.10.9 + github.com/stretchr/testify v1.9.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/stores/postgres/postgres.go b/stores/postgres/postgres.go index 8a1e3c2..30740b4 100644 --- a/stores/postgres/postgres.go +++ b/stores/postgres/postgres.go @@ -10,15 +10,12 @@ CREATE INDEX idx_sessions ON sessions (id, created_at); */ import ( - "crypto/rand" "database/sql" "encoding/json" - "errors" "fmt" - "sync" "time" - "unicode" + "github.com/lib/pq" _ "github.com/lib/pq" ) @@ -26,9 +23,8 @@ var ( // Error codes for store errors. This should match the codes // defined in the /simplesessions package exactly. ErrInvalidSession = &Err{code: 1, msg: "invalid session"} - ErrFieldNotFound = &Err{code: 2, msg: "field not found"} + ErrNil = &Err{code: 2, msg: "nil returned"} ErrAssertType = &Err{code: 3, msg: "assertion failed"} - ErrNil = &Err{code: 4, msg: "nil returned"} ) type Err struct { @@ -45,12 +41,13 @@ func (e *Err) Code() int { } type queries struct { - create *sql.Stmt - get *sql.Stmt - update *sql.Stmt - delete *sql.Stmt - clear *sql.Stmt - prune *sql.Stmt + create *sql.Stmt + get *sql.Stmt + update *sql.Stmt + delete *sql.Stmt + clear *sql.Stmt + prune *sql.Stmt + destroy *sql.Stmt } // Store represents redis session store for simple sessions. @@ -59,11 +56,6 @@ type Store struct { db *sql.DB opt Opt q *queries - - commitID string - tx *sql.Tx - stmt *sql.Stmt - sync.Mutex } type Opt struct { @@ -75,10 +67,6 @@ type Opt struct { CleanInterval time.Duration `json:"clean_interval"` } -const ( - sessionIDLen = 32 -) - // New creates a new Postgres store instance. func New(opt Opt, db *sql.DB) (*Store, error) { if opt.Table == "" { @@ -107,42 +95,24 @@ func New(opt Opt, db *sql.DB) (*Store, error) { } // Create creates a new session and returns the ID. -func (s *Store) Create() (string, error) { - id, err := generateID(sessionIDLen) - if err != nil { - return "", err - } - - if _, err := s.q.create.Exec(id); err != nil { - return "", err - } - return id, nil +func (s *Store) Create(id string) error { + _, err := s.q.create.Exec(id) + return err } // Get returns a single session field's value. func (s *Store) Get(id, key string) (interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - // Scan the whole JSON map out so that it can be unmarshalled, - // preserving the types. - var b []byte - if err := s.q.get.QueryRow(id, s.opt.TTL.Seconds()).Scan(&b); err != nil { + vals, err := s.GetAll(id) + if err != nil { if err == sql.ErrNoRows { return nil, ErrInvalidSession } return nil, err } - var mp map[string]interface{} - if err := json.Unmarshal(b, &mp); err != nil { - return nil, err - } - - v, ok := mp[key] + v, ok := vals[key] if !ok { - return nil, ErrFieldNotFound + return nil, nil } return v, nil @@ -150,10 +120,6 @@ func (s *Store) Get(id, key string) (interface{}, error) { // GetMulti gets a map for values for multiple keys. If a key doesn't exist, it returns ErrFieldNotFound. func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - vals, err := s.GetAll(id) if err != nil { return nil, err @@ -163,7 +129,7 @@ func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, err for _, k := range keys { v, ok := vals[k] if !ok { - return nil, ErrFieldNotFound + return nil, nil } out[k] = v } @@ -173,10 +139,6 @@ func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, err // GetAll returns the map of all keys in the session. func (s *Store) GetAll(id string) (map[string]interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - var b []byte err := s.q.get.QueryRow(id, s.opt.TTL.Seconds()).Scan(&b) if err != nil { @@ -193,45 +155,13 @@ func (s *Store) GetAll(id string) (map[string]interface{}, error) { // Set sets a value to given session but is stored only on commit. func (s *Store) Set(id, key string, val interface{}) (err error) { - if !validateID(id) { - return ErrInvalidSession - } - b, err := json.Marshal(map[string]interface{}{key: val}) if err != nil { return err } - s.Lock() - defer func() { - if err == nil { - s.Unlock() - return - } - - if s.tx != nil { - s.tx.Rollback() - s.tx = nil - } - s.stmt = nil - - s.Unlock() - }() - - // If a transaction isn't set, set it. - if s.tx == nil { - tx, err := s.db.Begin() - if err != nil { - return err - } - - // Prepare the statement for executing SQL commands - s.tx = tx - s.stmt = tx.Stmt(s.q.update) - } - // Execute the query in the batch to be committed later. - res, err := s.stmt.Exec(id, json.RawMessage(b)) + res, err := s.q.update.Exec(id, json.RawMessage(b)) if err != nil { return err } @@ -245,48 +175,37 @@ func (s *Store) Set(id, key string, val interface{}) (err error) { return ErrInvalidSession } - s.commitID = id return err } -// Commit sets all set values -func (s *Store) Commit(id string) error { - if !validateID(id) { - return ErrInvalidSession +// Set sets a value to given session but is stored only on commit. +func (s *Store) SetMulti(id string, data map[string]interface{}) (err error) { + b, err := json.Marshal(data) + if err != nil { + return err } - s.Lock() - if s.commitID != id { - s.Unlock() - return ErrInvalidSession + // Execute the query in the batch to be committed later. + res, err := s.q.update.Exec(id, json.RawMessage(b)) + if err != nil { + return err } - - defer func() { - if s.stmt != nil { - s.stmt.Close() - } - s.tx = nil - s.stmt = nil - s.Unlock() - }() - - if s.tx == nil { - return errors.New("nothing to commit") + num, err := res.RowsAffected() + if err != nil { + return err } - if s.commitID != id { + + // No row was updated. The session didn't exist. + if num == 0 { return ErrInvalidSession } - return s.tx.Commit() + return err } // Delete deletes a key from redis session hashmap. -func (s *Store) Delete(id string, key string) error { - if !validateID(id) { - return ErrInvalidSession - } - - res, err := s.q.delete.Exec(id, key) +func (s *Store) Delete(id string, keys ...string) error { + res, err := s.q.delete.Exec(id, pq.Array(keys)) if err != nil { return err } @@ -306,11 +225,27 @@ func (s *Store) Delete(id string, key string) error { // Clear clears session in redis. func (s *Store) Clear(id string) error { - if !validateID(id) { + res, err := s.q.clear.Exec(id) + if err != nil { + return err + } + + num, err := res.RowsAffected() + if err != nil { + return err + } + + // No row was updated. The session didn't exist. + if num == 0 { return ErrInvalidSession } - res, err := s.q.clear.Exec(id) + return nil +} + +// Destroy deletes the entire session from backend. +func (s *Store) Destroy(id string) error { + res, err := s.q.destroy.Exec(id) if err != nil { return err } @@ -454,7 +389,7 @@ func (s *Store) prepareQueries() (*queries, error) { return nil, err } - q.delete, err = s.db.Prepare(fmt.Sprintf("UPDATE %s SET data = data - $2 WHERE id=$1", s.opt.Table)) + q.delete, err = s.db.Prepare(fmt.Sprintf("UPDATE %s SET data = data #- $2 WHERE id=$1", s.opt.Table)) if err != nil { return nil, err } @@ -469,34 +404,10 @@ func (s *Store) prepareQueries() (*queries, error) { return nil, err } - return q, err -} - -func validateID(id string) bool { - if len(id) != sessionIDLen { - return false - } - - for _, r := range id { - if !unicode.IsDigit(r) && !unicode.IsLetter(r) { - return false - } - } - - return true -} - -// generateID generates a random alpha-num session ID. -func generateID(n int) (string, error) { - const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - bytes := make([]byte, n) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - - for k, v := range bytes { - bytes[k] = dict[v%byte(len(dict))] + q.destroy, err = s.db.Prepare(fmt.Sprintf("DELETE FROM %s WHERE id=$1", s.opt.Table)) + if err != nil { + return nil, err } - return string(bytes), nil + return q, err } diff --git a/stores/postgres/postgres_test.go b/stores/postgres/postgres_test.go index ed8a929..17976e3 100644 --- a/stores/postgres/postgres_test.go +++ b/stores/postgres/postgres_test.go @@ -3,7 +3,9 @@ package postgres // For this test to run, set env vars: PG_HOST, PG_PORT, PG_USER, PG_PASSWORD, PG_DB. import ( + "crypto/rand" "database/sql" + "errors" "fmt" "log" "os" @@ -17,11 +19,24 @@ import ( const testTable = "sessions" var ( - st *Store - db *sql.DB - randID, _ = generateID(sessionIDLen) + st *Store + db *sql.DB ) +func generateID() (string, error) { + const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + for k, v := range bytes { + bytes[k] = dict[v%byte(len(dict))] + } + + return string(bytes), nil +} + func init() { if os.Getenv("PG_HOST") == "" { fmt.Println("WARNING: Skiping DB test as database config isn't set in env vars.") @@ -47,36 +62,38 @@ func init() { } } +func TestNew(t *testing.T) { + s1, err := New(Opt{}, db) + assert.Nil(t, err) + assert.Equal(t, s1.opt.Table, "sessions") + assert.Equal(t, s1.opt.TTL, time.Hour*24) + + _, err = New(Opt{Table: "unknown"}, db) + assert.Error(t, err) +} + func TestCreate(t *testing.T) { - for n := 0; n < 5; n++ { - id, err := st.Create() - assert.NoError(t, err) - assert.NotEmpty(t, id) - } + id, _ := generateID() + err := st.Create(id) + assert.NoError(t, err) + + var data []byte + err = db.QueryRow(fmt.Sprintf("SELECT data FROM %s WHERE id=$1", testTable), id).Scan(&data) + assert.NoError(t, err) + assert.Equal(t, []byte("{}"), data) } -func TestSet(t *testing.T) { - assert.NotEmpty(t, randID) +func TestAll(t *testing.T) { + id, _ := generateID() - id, err := st.Create() + err := st.Create(id) assert.NoError(t, err) - assert.NotEmpty(t, id) assert.NoError(t, st.Set(id, "num", 123)) assert.NoError(t, st.Set(id, "float", 12.3)) assert.NoError(t, st.Set(id, "str", "hello 123")) assert.NoError(t, st.Set(id, "bool", true)) - // Commit invalid session. - assert.Error(t, st.Commit(randID), ErrInvalidSession) - - // Commit valid session. - assert.NoError(t, st.Commit(id)) - - // Commit without setting. - assert.Error(t, st.Commit(id)) - assert.Error(t, st.Commit(randID)) - // Get different types. v, err := st.Get(id, "num") assert.NoError(t, err) @@ -86,42 +103,91 @@ func TestSet(t *testing.T) { v, err := st.Int(st.Get(id, "num")) assert.NoError(t, err) assert.Equal(t, v, int(123)) + + _, err = st.Int("xxx", nil) + assert.ErrorIs(t, err, ErrAssertType) + + cErr := errors.New("type error") + _, err = st.Int("xxx", cErr) + assert.ErrorIs(t, err, cErr) } { v, err := st.Int64(st.Get(id, "num")) assert.NoError(t, err) assert.Equal(t, v, int64(123)) + + _, err = st.Int64("xxx", nil) + assert.ErrorIs(t, err, ErrAssertType) + + cErr := errors.New("type error") + _, err = st.Int64("xxx", cErr) + assert.ErrorIs(t, err, cErr) } { v, err := st.UInt64(st.Get(id, "num")) assert.NoError(t, err) assert.Equal(t, v, uint64(123)) + + _, err = st.UInt64("xxx", nil) + assert.ErrorIs(t, err, ErrAssertType) + + cErr := errors.New("type error") + _, err = st.UInt64("xxx", cErr) + assert.ErrorIs(t, err, cErr) } { v, err := st.Float64(st.Get(id, "float")) assert.NoError(t, err) assert.Equal(t, v, float64(12.3)) + + _, err = st.Float64("xxx", nil) + assert.ErrorIs(t, err, ErrAssertType) + + cErr := errors.New("type error") + _, err = st.Float64("xxx", cErr) + assert.ErrorIs(t, err, cErr) } { v, err := st.String(st.Get(id, "str")) assert.NoError(t, err) assert.Equal(t, v, "hello 123") + + _, err = st.String(1, nil) + assert.ErrorIs(t, err, ErrAssertType) + + cErr := errors.New("type error") + _, err = st.String("xxx", cErr) + assert.ErrorIs(t, err, cErr) } { v, err := st.Bytes(st.Get(id, "str")) assert.NoError(t, err) assert.Equal(t, v, []byte("hello 123")) + + _, err = st.Bytes(1, nil) + assert.ErrorIs(t, err, ErrAssertType) + + cErr := errors.New("type error") + _, err = st.Bytes("xxx", cErr) + assert.ErrorIs(t, err, cErr) } { v, err := st.Bool(st.Get(id, "bool")) assert.NoError(t, err) assert.Equal(t, v, true) + + _, err = st.Bool("xxx", nil) + assert.ErrorIs(t, err, ErrAssertType) + + cErr := errors.New("type error") + _, err = st.Bool("xxx", cErr) + assert.ErrorIs(t, err, cErr) } { @@ -137,8 +203,9 @@ func TestSet(t *testing.T) { } // Non-existent field. - _, err = st.Get(id, "xx") - assert.ErrorIs(t, err, ErrFieldNotFound) + v, err = st.Get(id, "xx") + assert.Nil(t, v) + assert.Nil(t, err) // Get multiple. mp, err := st.GetMulti(id, "num", "str", "bool") @@ -149,39 +216,62 @@ func TestSet(t *testing.T) { "bool": true, }) mp, err = st.GetMulti(id, "num", "str", "bool", "blah") - assert.ErrorIs(t, err, ErrFieldNotFound) + assert.Nil(t, mp["blah"]) + assert.Nil(t, err) // Add another key in a different commit. assert.NoError(t, st.Set(id, "num2", 456)) - assert.NoError(t, st.Commit(id)) + + assert.NoError(t, st.SetMulti(id, map[string]interface{}{ + "num10": 1, + "num11": 2, + })) v, err = st.Get(id, "num2") assert.NoError(t, err) assert.Equal(t, v, float64(456)) + v, err = st.Get(id, "num10") + assert.NoError(t, err) + assert.Equal(t, v, float64(1)) + + v, err = st.Get(id, "num11") + assert.NoError(t, err) + assert.Equal(t, v, float64(2)) + // Delete. assert.ErrorIs(t, st.Delete("blah", "num2"), ErrInvalidSession) assert.NoError(t, st.Delete(id, "num2")) v, err = st.Get(id, "num2") + assert.Nil(t, v) + assert.Nil(t, err) v, err = st.Get(id, "num3") - assert.Error(t, ErrFieldNotFound) + assert.Nil(t, v) + assert.Nil(t, err) // Clear. - assert.ErrorIs(t, st.Clear(randID), ErrInvalidSession) + assert.ErrorIs(t, st.Clear("unknow_id"), ErrInvalidSession) assert.NoError(t, st.Clear(id)) v, err = st.Get(id, "str") - assert.Error(t, err, ErrFieldNotFound) + assert.Nil(t, v) + assert.Nil(t, err) + + // Destroy. + assert.ErrorIs(t, st.Destroy("unknow_id"), ErrInvalidSession) + assert.NoError(t, st.Destroy(id)) + _, err = st.Get(id, "str") + assert.ErrorIs(t, err, ErrInvalidSession) } func TestPrune(t *testing.T) { + id, _ := generateID() + // Create a new session. - id, err := st.Create() + err := st.Create(id) assert.NoError(t, err) - assert.NotEmpty(t, id) // Set value. assert.NoError(t, st.Set(id, "str", "hello 123")) - assert.NoError(t, st.Commit(id)) // Get value and verify. v, err := st.Get(id, "str") @@ -197,10 +287,10 @@ func TestPrune(t *testing.T) { // Create one more session and immediately run prune. Except for this, // all previous sessions should be gone. - id, err = st.Create() + id, _ = generateID() + err = st.Create(id) assert.NoError(t, err) assert.NoError(t, st.Set(id, "str", "hello 123")) - assert.NoError(t, st.Commit(id)) // Run prune. All previously created sessions should be gone. assert.NoError(t, st.Prune()) @@ -216,3 +306,12 @@ func TestPrune(t *testing.T) { assert.Equal(t, v, "hello 123") } + +func TestError(t *testing.T) { + err := Err{ + code: 1, + msg: "test", + } + assert.Equal(t, 1, err.Code()) + assert.Equal(t, "test", err.Error()) +} diff --git a/stores/redis/go.mod b/stores/redis/go.mod index b15a078..3779bde 100644 --- a/stores/redis/go.mod +++ b/stores/redis/go.mod @@ -1,16 +1,18 @@ -module github.com/vividvilla/simplesessions/stores/redis/v2 +module github.com/vividvilla/simplesessions/stores/redis/v3 go 1.18 require ( github.com/alicebob/miniredis/v2 v2.32.1 - github.com/gomodule/redigo v1.9.2 + github.com/redis/go-redis/v9 v9.5.1 github.com/stretchr/testify v1.9.0 ) require ( github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/stores/redis/store.go b/stores/redis/store.go index 50977dd..086978f 100644 --- a/stores/redis/store.go +++ b/stores/redis/store.go @@ -1,22 +1,19 @@ package redis import ( - "crypto/rand" - "errors" - "sync" + "context" + "strconv" "time" - "unicode" - "github.com/gomodule/redigo/redis" + "github.com/redis/go-redis/v9" ) var ( // Error codes for store errors. This should match the codes // defined in the /simplesessions package exactly. ErrInvalidSession = &Err{code: 1, msg: "invalid session"} - ErrFieldNotFound = &Err{code: 2, msg: "field not found"} + ErrNil = &Err{code: 2, msg: "nil returned"} ErrAssertType = &Err{code: 3, msg: "assertion failed"} - ErrNil = &Err{code: 4, msg: "nil returned"} ) type Err struct { @@ -37,30 +34,31 @@ func (e *Err) Code() int { type Store struct { // Maximum lifetime sessions has to be persisted. ttl time.Duration + // extend TTL on update. + extendTTL bool // Prefix for session id. prefix string - // Temp map to store values before commit. - tempSetMap map[string]map[string]interface{} - mu sync.RWMutex - - // Redis pool - pool *redis.Pool + // Redis client + client redis.UniversalClient + clientCtx context.Context } const ( // Default prefix used to store session redis defaultPrefix = "session:" - sessionIDLen = 32 + // Default key used when session is created. + // Its not possible to have empty map in Redis. + defaultSessKey = "_ss" ) // New creates a new Redis store instance. -func New(pool *redis.Pool) *Store { +func New(ctx context.Context, client redis.UniversalClient) *Store { return &Store{ - pool: pool, - prefix: defaultPrefix, - tempSetMap: make(map[string]map[string]interface{}), + clientCtx: ctx, + client: client, + prefix: defaultPrefix, } } @@ -70,63 +68,59 @@ func (s *Store) SetPrefix(val string) { } // SetTTL sets TTL for session in redis. -func (s *Store) SetTTL(d time.Duration) { +// if isExtend is true then ttl is updated on all set/setmulti. +// otherwise its set only on create(). +func (s *Store) SetTTL(d time.Duration, extend bool) { s.ttl = d + s.extendTTL = extend } // Create returns a new session id but doesn't stores it in redis since empty hashmap can't be created. -func (s *Store) Create() (string, error) { - id, err := generateID(sessionIDLen) - if err != nil { - return "", err +func (s *Store) Create(id string) error { + // Create the session in backend with default session key since + // Redis doesn't support empty hashmap and its impossible to + // check if the session exist or not. + p := s.client.TxPipeline() + p.HSet(s.clientCtx, s.prefix+id, defaultSessKey, "1") + if s.ttl > 0 { + p.Expire(s.clientCtx, s.prefix+id, s.ttl) } - - return id, err + _, err := p.Exec(s.clientCtx) + return err } // Get gets a field in hashmap. If field is nill then ErrFieldNotFound is raised func (s *Store) Get(id, key string) (interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession + vals, err := s.client.HMGet(s.clientCtx, s.prefix+id, defaultSessKey, key).Result() + if err != nil { + return nil, err } - conn := s.pool.Get() - defer conn.Close() - - v, err := conn.Do("HGET", s.prefix+id, key) - if v == nil || err == redis.ErrNil { - return nil, ErrFieldNotFound + if vals[0] == nil { + return nil, ErrInvalidSession } - return v, err + return vals[1], nil } // GetMulti gets a map for values for multiple keys. If key is not found then its set as nil. func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession - } - - conn := s.pool.Get() - defer conn.Close() - - // Make list of args for HMGET - args := make([]interface{}, len(keys)+1) - args[0] = s.prefix + id - for i := range keys { - args[i+1] = keys[i] + allKeys := append([]string{defaultSessKey}, keys...) + vals, err := s.client.HMGet(s.clientCtx, s.prefix+id, allKeys...).Result() + if err != nil { + return nil, err } - v, err := redis.Values(conn.Do("HMGET", args...)) - // If field is not found then return map with fields as nil - if len(v) == 0 || err == redis.ErrNil { - v = make([]interface{}, len(keys)) + if vals[0] == nil { + return nil, ErrInvalidSession } // Form a map with returned results res := make(map[string]interface{}) - for i, k := range keys { - res[k] = v[i] + for i, k := range allKeys { + if k != defaultSessKey { + res[k] = vals[i] + } } return res, err @@ -134,210 +128,284 @@ func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, err // GetAll gets all fields from hashmap. func (s *Store) GetAll(id string) (map[string]interface{}, error) { - if !validateID(id) { - return nil, ErrInvalidSession + vals, err := s.client.HGetAll(s.clientCtx, s.prefix+id).Result() + if err != nil { + return nil, err } - conn := s.pool.Get() - defer conn.Close() + // Convert results to type `map[string]interface{}` + out := make(map[string]interface{}) + for k, v := range vals { + if k != defaultSessKey { + out[k] = v + } + } - return s.interfaceMap(conn.Do("HGETALL", s.prefix+id)) + return out, nil } -// Set sets a value to given session but stored only on commit +// Set sets a value to given session. +// If session is not present in backend then its still written. func (s *Store) Set(id, key string, val interface{}) error { - if !validateID(id) { - return ErrInvalidSession - } - - s.mu.Lock() - defer s.mu.Unlock() + p := s.client.TxPipeline() + p.HSet(s.clientCtx, s.prefix+id, key, val) + p.HSet(s.clientCtx, s.prefix+id, defaultSessKey, "1") - // Create session map if doesn't exist - if _, ok := s.tempSetMap[id]; !ok { - s.tempSetMap[id] = make(map[string]interface{}) + // Set expiry of key only if 'ttl' is set, this is to + // ensure that the key remains valid indefinitely like + // how redis handles it by default + if s.ttl > 0 && s.extendTTL { + p.Expire(s.clientCtx, s.prefix+id, s.ttl) } - // set value to map - s.tempSetMap[id][key] = val - - return nil + _, err := p.Exec(s.clientCtx) + return err } -// Commit sets all set values -func (s *Store) Commit(id string) error { - if !validateID(id) { - return ErrInvalidSession - } - - s.mu.RLock() - vals, ok := s.tempSetMap[id] - if !ok { - // Nothing to commit - s.mu.RUnlock() - return nil - } - +// Set sets a value to given session. +func (s *Store) SetMulti(id string, data map[string]interface{}) error { // Make slice of arguments to be passed in HGETALL command - args := make([]interface{}, len(vals)*2+1, len(vals)*2+1) - args[0] = s.prefix + id - - c := 1 - for k, v := range s.tempSetMap[id] { - args[c] = k - args[c+1] = v - c += 2 + args := []interface{}{defaultSessKey, "1"} + for k, v := range data { + args = append(args, k, v) } - s.mu.RUnlock() - - // Clear temp map for given session id - s.mu.Lock() - delete(s.tempSetMap, id) - s.mu.Unlock() - - // Set to redis - conn := s.pool.Get() - defer conn.Close() - - conn.Send("MULTI") - conn.Send("HMSET", args...) + p := s.client.TxPipeline() + p.HMSet(s.clientCtx, s.prefix+id, args...) // Set expiry of key only if 'ttl' is set, this is to // ensure that the key remains valid indefinitely like // how redis handles it by default - if s.ttl > 0 { - conn.Send("EXPIRE", args[0], s.ttl.Seconds()) + if s.ttl > 0 && s.extendTTL { + p.Expire(s.clientCtx, s.prefix+id, s.ttl) } - res, err := redis.Values(conn.Do("EXEC")) - if err != nil { - return err - } - - for _, r := range res { - if v, ok := r.(redis.Error); ok { - return v - } - } - - return nil + _, err := p.Exec(s.clientCtx) + return err } // Delete deletes a key from redis session hashmap. -func (s *Store) Delete(id string, key string) error { - if !validateID(id) { - return ErrInvalidSession - } - - // Clear temp map for given session id - s.mu.Lock() - delete(s.tempSetMap, id) - s.mu.Unlock() - - conn := s.pool.Get() - defer conn.Close() - - _, err := conn.Do("HDEL", s.prefix+id, key) - return err +func (s *Store) Delete(id string, keys ...string) error { + return s.client.HDel(s.clientCtx, s.prefix+id, keys...).Err() } // Clear clears session in redis. func (s *Store) Clear(id string) error { - if !validateID(id) { - return ErrInvalidSession + p := s.client.TxPipeline() + p.Del(s.clientCtx, s.prefix+id).Err() + p.HSet(s.clientCtx, s.prefix+id, defaultSessKey, "1") + if s.ttl > 0 { + p.Expire(s.clientCtx, s.prefix+id, s.ttl) } - - conn := s.pool.Get() - defer conn.Close() - - _, err := conn.Do("DEL", s.prefix+id) + _, err := p.Exec(s.clientCtx) return err } -// interfaceMap is a helper method which converts HGETALL reply to map of string interface -func (s *Store) interfaceMap(result interface{}, err error) (map[string]interface{}, error) { - values, err := redis.Values(result, err) - if err != nil { - return nil, err - } +// Destroy deletes the entire session from backend. +func (s *Store) Destroy(id string) error { + return s.client.Del(s.clientCtx, s.prefix+id).Err() +} - if len(values)%2 != 0 { - return nil, errors.New("redigo: StringMap expects even number of values result") +// Int converts interface to integer. +func (s *Store) Int(r interface{}, err error) (int, error) { + if err != nil { + return 0, err } - m := make(map[string]interface{}, len(values)/2) - for i := 0; i < len(values); i += 2 { - key, ok := values[i].([]byte) - if !ok { - return nil, errors.New("redigo: StringMap key not a bulk string value") + switch r := r.(type) { + case int: + return r, nil + case int64: + if x := int(r); int64(x) != r { + return 0, ErrAssertType + } else { + return x, nil } - - m[string(key)] = values[i+1] + case []byte: + if n, err := strconv.ParseInt(string(r), 10, 0); err != nil { + return 0, ErrAssertType + } else { + return int(n), nil + } + case string: + if n, err := strconv.ParseInt(r, 10, 0); err != nil { + return 0, ErrAssertType + } else { + return int(n), nil + } + case nil: + return 0, ErrNil + case error: + return 0, r } - return m, nil + return 0, ErrAssertType } -// Int returns redis reply as integer. -func (s *Store) Int(r interface{}, err error) (int, error) { - return redis.Int(r, err) -} - -// Int64 returns redis reply as Int64. +// Int64 converts interface to Int64. func (s *Store) Int64(r interface{}, err error) (int64, error) { - return redis.Int64(r, err) -} + if err != nil { + return 0, err + } -// UInt64 returns redis reply as UInt64. -func (s *Store) UInt64(r interface{}, err error) (uint64, error) { - return redis.Uint64(r, err) -} + switch r := r.(type) { + case int: + return int64(r), nil + case int64: + return r, nil + case []byte: + if n, err := strconv.ParseInt(string(r), 10, 64); err != nil { + return 0, ErrAssertType + } else { + return n, nil + } + case string: + if n, err := strconv.ParseInt(r, 10, 64); err != nil { + return 0, ErrAssertType + } else { + return n, nil + } + case nil: + return 0, ErrNil + case error: + return 0, r + } -// Float64 returns redis reply as Float64. -func (s *Store) Float64(r interface{}, err error) (float64, error) { - return redis.Float64(r, err) + return 0, ErrAssertType } -// String returns redis reply as String. -func (s *Store) String(r interface{}, err error) (string, error) { - return redis.String(r, err) -} +// UInt64 converts interface to UInt64. +func (s *Store) UInt64(r interface{}, err error) (uint64, error) { + if err != nil { + return 0, err + } -// Bytes returns redis reply as Bytes. -func (s *Store) Bytes(r interface{}, err error) ([]byte, error) { - return redis.Bytes(r, err) -} + switch r := r.(type) { + case uint64: + return r, nil + case int: + if r < 0 { + return 0, ErrAssertType + } + return uint64(r), nil + case int64: + if r < 0 { + return 0, ErrAssertType + } + return uint64(r), nil + case []byte: + if n, err := strconv.ParseUint(string(r), 10, 64); err != nil { + return 0, ErrAssertType + } else { + return n, nil + } + case string: + if n, err := strconv.ParseUint(r, 10, 64); err != nil { + return 0, ErrAssertType + } else { + return n, nil + } + case nil: + return 0, ErrNil + case error: + return 0, r + } -// Bool returns redis reply as Bool. -func (s *Store) Bool(r interface{}, err error) (bool, error) { - return redis.Bool(r, err) + return 0, ErrAssertType } -func validateID(id string) bool { - if len(id) != sessionIDLen { - return false +// Float64 converts interface to Float64. +func (s *Store) Float64(r interface{}, err error) (float64, error) { + if err != nil { + return 0, err } - - for _, r := range id { - if !unicode.IsDigit(r) && !unicode.IsLetter(r) { - return false + switch r := r.(type) { + case float64: + return r, err + case []byte: + if n, err := strconv.ParseFloat(string(r), 64); err != nil { + return 0, ErrAssertType + } else { + return n, nil + } + case string: + if n, err := strconv.ParseFloat(r, 64); err != nil { + return 0, ErrAssertType + } else { + return n, nil } + case nil: + return 0, ErrNil + case error: + return 0, r } - - return true + return 0, ErrAssertType } -// generateID generates a random alpha-num session ID. -func generateID(n int) (string, error) { - const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - bytes := make([]byte, n) - if _, err := rand.Read(bytes); err != nil { +// String converts interface to String. +func (s *Store) String(r interface{}, err error) (string, error) { + if err != nil { return "", err } + switch r := r.(type) { + case []byte: + return string(r), nil + case string: + return r, nil + case nil: + return "", ErrNil + case error: + return "", r + } + return "", ErrAssertType +} - for k, v := range bytes { - bytes[k] = dict[v%byte(len(dict))] +// Bytes converts interface to Bytes. +func (s *Store) Bytes(r interface{}, err error) ([]byte, error) { + if err != nil { + return nil, err + } + switch r := r.(type) { + case []byte: + return r, nil + case string: + return []byte(r), nil + case nil: + return nil, ErrNil + case error: + return nil, r } + return nil, ErrAssertType +} - return string(bytes), nil +// Bool converts interface to Bool. +func (s *Store) Bool(r interface{}, err error) (bool, error) { + if err != nil { + return false, err + } + switch r := r.(type) { + case bool: + return r, err + // Very common in redis to reply int64 with 0 for bool flag. + case int: + return r != 0, nil + case int64: + return r != 0, nil + case []byte: + if n, err := strconv.ParseBool(string(r)); err != nil { + return false, ErrAssertType + } else { + return n, nil + } + case string: + if n, err := strconv.ParseBool(r); err != nil { + return false, ErrAssertType + } else { + return n, nil + } + case nil: + return false, ErrNil + case error: + return false, r + } + return false, ErrAssertType } diff --git a/stores/redis/store_test.go b/stores/redis/store_test.go index f9bcd5e..bf029b9 100644 --- a/stores/redis/store_test.go +++ b/stores/redis/store_test.go @@ -1,17 +1,19 @@ package redis import ( + "context" "errors" "testing" "time" "github.com/alicebob/miniredis/v2" - "github.com/gomodule/redigo/redis" + "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" ) var ( mockRedis *miniredis.Miniredis + errTest = errors.New("test error") ) func init() { @@ -22,550 +24,569 @@ func init() { } } -func getRedisPool() *redis.Pool { - return &redis.Pool{ - Wait: true, - Dial: func() (redis.Conn, error) { - c, err := redis.Dial( - "tcp", - mockRedis.Addr(), - ) - - return c, err - }, - } +func getRedisClient() redis.UniversalClient { + return redis.NewClient(&redis.Options{ + Addr: mockRedis.Addr(), + }) } func TestNew(t *testing.T) { - assert := assert.New(t) - rPool := getRedisPool() - str := New(rPool) - assert.Equal(str.prefix, defaultPrefix) - assert.Equal(str.pool, rPool) - assert.NotNil(str.tempSetMap) + client := getRedisClient() + ctx := context.Background() + str := New(ctx, client) + assert.Equal(t, str.prefix, defaultPrefix) + assert.Equal(t, str.client, client) + assert.Equal(t, str.clientCtx, ctx) } func TestSetPrefix(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + str := New(context.TODO(), getRedisClient()) str.SetPrefix("test") - assert.Equal(str.prefix, "test") + assert.Equal(t, str.prefix, "test") } func TestSetTTL(t *testing.T) { - assert := assert.New(t) testDur := time.Second * 10 - str := New(getRedisPool()) - str.SetTTL(testDur) - assert.Equal(str.ttl, testDur) + str := New(context.TODO(), getRedisClient()) + str.SetTTL(testDur, true) + assert.Equal(t, str.ttl, testDur) + assert.True(t, str.extendTTL) } func TestCreate(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - - id, err := str.Create() - assert.Nil(err) - assert.Equal(len(id), sessionIDLen) -} + var ( + id = "testid_create" + client = getRedisClient() + str = New(context.TODO(), client) + ) + str.SetTTL(time.Second*100, false) + err := str.Create(id) + assert.Nil(t, err) -func TestGetInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + vals, err := client.HGetAll(context.TODO(), str.prefix+id).Result() + assert.NoError(t, err) + assert.Contains(t, vals, defaultSessKey) - val, err := str.Get("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) + ttl, _ := client.TTL(context.TODO(), str.prefix+id).Result() + assert.Equal(t, ttl, time.Second*100) } func TestGet(t *testing.T) { - assert := assert.New(t) - key := "4dIHy6S2uBuKaNnTUszB218L898ikGY1" - field := "somekey" - value := 100 - redisPool := getRedisPool() - - // Set a key - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HSET", defaultPrefix+key, field, value) - assert.NoError(err) - - str := New(redisPool) + var ( + id = "testid_get" + field = "somekey" + value = 100 + client = getRedisClient() + str = New(context.TODO(), client) + ) + // Invalid session. + val, err := str.Get("invalidkey", "invalidkey") + assert.Nil(t, val) + assert.ErrorIs(t, err, ErrInvalidSession) - val, err := redis.Int(str.Get(key, field)) - assert.NoError(err) - assert.Equal(val, value) -} + // Check valid session. + err = client.HMSet(context.TODO(), str.prefix+id, field, value, defaultSessKey, "1").Err() + assert.NoError(t, err) -func TestGetFieldNotFoundError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + val, err = str.Int(str.Get(id, field)) + assert.NoError(t, err) + assert.Equal(t, val, value) - key := "10IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err := str.Get(key, "invalidkey") - assert.Nil(val) - assert.Error(err, ErrFieldNotFound.Error()) + // Check for invalid key. + _, err = str.Int(str.Get(id, "invalidfield")) + assert.ErrorIs(t, ErrNil, err) } -func TestGetMultiInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - +func TestGetMulti(t *testing.T) { + var ( + id = "testid_getmulti" + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + invalidField = "foo" + client = getRedisClient() + str = New(context.TODO(), client) + ) + // Invalid session. val, err := str.GetMulti("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) -} + assert.Nil(t, val) + assert.ErrorIs(t, err, ErrInvalidSession) -func TestGetMultiFieldEmptySession(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - - key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" - field := "somefield" - _, err := str.GetMulti(key, field) - assert.Nil(err) -} + // Set a key + err = client.HMSet(context.TODO(), str.prefix+id, defaultSessKey, "1", field1, value1, field2, value2).Err() + assert.NoError(t, err) -func TestGetMulti(t *testing.T) { - assert := assert.New(t) - key := "5dIHy6S2uBuKaNnTUszB218L898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - field3 := "thishouldntbethere" - value3 := 100.10 - redisPool := getRedisPool() + vals, err := str.GetMulti(id, field1, field2, invalidField) + assert.NoError(t, err) + assert.Contains(t, vals, field1) + assert.Contains(t, vals, field2) + assert.Contains(t, vals, invalidField) - // Set a key - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HMSET", defaultPrefix+key, field1, value1, field2, value2, field3, value3) - assert.NoError(err) - - str := New(redisPool) - - vals, err := str.GetMulti(key, field1, field2) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - assert.NotContains(vals, field3) - - val1, err := redis.Int(vals[field1], nil) - assert.NoError(err) - assert.Equal(val1, value1) - - val2, err := redis.String(vals[field2], nil) - assert.NoError(err) - assert.Equal(val2, value2) -} + val1, err := str.Int(vals[field1], nil) + assert.NoError(t, err) + assert.Equal(t, val1, value1) -func TestGetAllInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + val2, err := str.String(vals[field2], nil) + assert.NoError(t, err) + assert.Equal(t, val2, value2) - val, err := str.GetAll("invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) + // Check for invalid key. + _, err = str.String(vals[invalidField], nil) + assert.ErrorIs(t, ErrNil, err) } func TestGetAll(t *testing.T) { - assert := assert.New(t) - key := "6dIHy6S2uBuKaNnTUszB218L898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - field3 := "thishouldntbethere" - value3 := 100.10 - redisPool := getRedisPool() + var ( + key = "testid_getall" + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + client = getRedisClient() + str = New(context.TODO(), client) + ) // Set a key - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HMSET", defaultPrefix+key, field1, value1, field2, value2, field3, value3) - assert.NoError(err) - - str := New(redisPool) + err := client.HMSet(context.TODO(), str.prefix+key, defaultSessKey, "1", field1, value1, field2, value2).Err() + assert.NoError(t, err) vals, err := str.GetAll(key) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - assert.Contains(vals, field3) - - val1, err := redis.Int(vals[field1], nil) - assert.NoError(err) - assert.Equal(val1, value1) - - val2, err := redis.String(vals[field2], nil) - assert.NoError(err) - assert.Equal(val2, value2) - - val3, err := redis.Float64(vals[field3], nil) - assert.NoError(err) - assert.Equal(val3, value3) -} + assert.NoError(t, err) + assert.Contains(t, vals, field1) + assert.Contains(t, vals, field2) -func TestSetInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + val1, err := str.Int(vals[field1], nil) + assert.NoError(t, err) + assert.Equal(t, val1, value1) - err := str.Set("invalidid", "key", "value") - assert.Error(err, ErrInvalidSession.Error()) + val2, err := str.String(vals[field2], nil) + assert.NoError(t, err) + assert.Equal(t, val2, value2) } func TestSet(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) - - // this key is unique across all tests - key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" - field := "somekey" - value := 100 - - assert.NotNil(str.tempSetMap) - assert.NotContains(str.tempSetMap, key) + var ( + client = getRedisClient() + str = New(context.TODO(), client) + ttl = time.Second * 10 + // this key is unique across all tests + key = "testid_set" + field = "somekey" + value = 100 + ) + str.SetTTL(ttl, true) err := str.Set(key, field, value) - assert.NoError(err) - assert.Contains(str.tempSetMap, key) - assert.Contains(str.tempSetMap[key], field) - assert.Equal(str.tempSetMap[key][field], value) + assert.NoError(t, err) // Check ifs not commited to redis - conn := redisPool.Get() - defer conn.Close() - val, err := conn.Do("TTL", defaultPrefix+key) - assert.NoError(err) - // -2 represents key doesn't exist - assert.Equal(val, int64(-2)) -} + v1, err := client.Exists(context.TODO(), str.prefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, int64(1), v1) -func TestCommitInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + v2, err := str.Int(client.HGet(context.TODO(), str.prefix+key, field).Result()) + assert.NoError(t, err) + assert.Equal(t, value, v2) - err := str.Commit("invalidkey") - assert.Error(err, ErrInvalidSession.Error()) + dur, err := client.TTL(context.TODO(), str.prefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, dur, ttl) } -func TestEmptyCommit(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) - - err := str.Commit("15IHy6S2uBuKaNnTUszB2180898ikGY1") - assert.NoError(err) -} - -func TestCommit(t *testing.T) { - // Test should commit in redis with expiry on key - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) - - str.SetTTL(10 * time.Second) +func TestSetMulti(t *testing.T) { + // Test should only set in internal map and not in redis + var ( + client = getRedisClient() + str = New(context.TODO(), client) + ttl = time.Second * 10 + key = "testid_setmulti" + field1 = "somekey1" + value1 = 100 + field2 = "somekey2" + value2 = "somevalue" + ) + str.SetTTL(ttl, true) + + err := str.SetMulti(key, map[string]interface{}{ + field1: value1, + field2: value2, + }) + assert.NoError(t, err) - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" + // Check ifs not commited to redis + v1, err := client.Exists(context.TODO(), str.prefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, int64(1), v1) - err := str.Set(key, field1, value1) - assert.NoError(err) + v2, err := str.Int(client.HGet(context.TODO(), str.prefix+key, field1).Result()) + assert.NoError(t, err) + assert.Equal(t, value1, v2) - err = str.Set(key, field2, value2) - assert.NoError(err) + dur, err := client.TTL(context.TODO(), str.prefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, dur, ttl) +} - err = str.Commit(key) - assert.NoError(err) +func TestDelete(t *testing.T) { + // Test should only set in internal map and not in redis + var ( + client = getRedisClient() + str = New(context.TODO(), client) - conn := redisPool.Get() - defer conn.Close() - vals, err := redis.Values(conn.Do("HGETALL", defaultPrefix+key)) - assert.Equal(2*2, len(vals)) + // this key is unique across all tests + key = "testid_delete" + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + ) - ttl, err := redis.Int(conn.Do("TTL", defaultPrefix+key)) - assert.NoError(err) + err := client.HMSet(context.TODO(), str.prefix+key, defaultSessKey, "1", field1, value1, field2, value2).Err() + assert.NoError(t, err) - assert.Equal(true, ttl > 0 && ttl <= 10) -} + err = str.Delete(key, field1) + assert.NoError(t, err) -func TestDeleteInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + val, err := client.HExists(context.TODO(), str.prefix+key, field1).Result() + assert.False(t, val) + assert.NoError(t, err) - err := str.Delete("invalidkey", "somefield") - assert.Error(err, ErrInvalidSession.Error()) + val, err = client.HExists(context.TODO(), str.prefix+key, field2).Result() + assert.True(t, val) + assert.NoError(t, err) } -func TestDelete(t *testing.T) { +func TestClear(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) - - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HMSET", defaultPrefix+key, field1, value1, field2, value2) - assert.NoError(err) + var ( + client = getRedisClient() + str = New(context.TODO(), client) - err = str.Delete(key, field1) - assert.NoError(err) + // this key is unique across all tests + key = "testid_clear" + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + ) - val, err := redis.Bool(conn.Do("HEXISTS", defaultPrefix+key, field1)) - assert.False(val) + err := client.HMSet(context.TODO(), str.prefix+key, defaultSessKey, "1", field1, value1, field2, value2).Err() + assert.NoError(t, err) - val, err = redis.Bool(conn.Do("HEXISTS", defaultPrefix+key, field2)) - assert.True(val) -} + err = str.Clear(key) + assert.NoError(t, err) -func TestClearInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(getRedisPool()) + val, err := client.HExists(context.TODO(), str.prefix+key, defaultSessKey).Result() + assert.NoError(t, err) + assert.True(t, val) - err := str.Clear("invalidkey") - assert.Error(err, ErrInvalidSession.Error()) + val, err = client.HExists(context.TODO(), str.prefix+key, field1).Result() + assert.NoError(t, err) + assert.False(t, val) } -func TestClear(t *testing.T) { +func TestDestroy(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) - - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HMSET", defaultPrefix+key, field1, value1, field2, value2) - assert.NoError(err) - - // Check if its set - val, err := conn.Do("TTL", defaultPrefix+key) - assert.NoError(err) - // -2 represents key doesn't exist - assert.NotEqual(val, int64(-2)) + var ( + client = getRedisClient() + str = New(context.TODO(), client) - err = str.Clear(key) - assert.NoError(err) + // this key is unique across all tests + key = "testid_clear" + field1 = "somekey" + value1 = 100 + ) - val, err = conn.Do("TTL", defaultPrefix+key) - assert.NoError(err) - // -2 represents key doesn't exist - assert.Equal(val, int64(-2)) -} + err := client.HMSet(context.TODO(), str.prefix+key, defaultSessKey, "1", field1, value1).Err() + assert.NoError(t, err) -func TestInterfaceMap(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) - - // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("HMSET", defaultPrefix+key, field1, value1, field2, value2) - assert.NoError(err) - - vals, err := str.interfaceMap(conn.Do("HGETALL", defaultPrefix+key)) - assert.Contains(vals, field1) - assert.Contains(vals, field2) -} + err = str.Destroy(key) + assert.NoError(t, err) -func TestInterfaceMapWithError(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) - - testError := errors.New("test error") - vals, err := str.interfaceMap(nil, testError) - assert.Nil(vals) - assert.Error(err, testError.Error()) - - valsInfSlice := []interface{}{nil, nil, nil} - vals, err = str.interfaceMap(valsInfSlice, nil) - assert.Nil(vals) - assert.Equal(err.Error(), "redigo: StringMap expects even number of values result") - - valsInfSlice = []interface{}{"abc123", 123} - vals, err = str.interfaceMap(valsInfSlice, nil) - assert.Nil(vals) - assert.Equal(err.Error(), "redigo: StringMap key not a bulk string value") + val, err := client.Exists(context.TODO(), str.prefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, val, int64(0)) } func TestInt(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + str := New(context.TODO(), nil) + + v, err := str.Int(1, nil) + assert.NoError(t, err) + assert.Equal(t, 1, v) + + v, err = str.Int("1", nil) + assert.NoError(t, err) + assert.Equal(t, 1, v) - field := "somekey" - value := 100 + v, err = str.Int([]byte("1"), nil) + assert.NoError(t, err) + assert.Equal(t, 1, v) - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + var tVal int64 = 1 + v, err = str.Int(tVal, nil) + assert.NoError(t, err) + assert.Equal(t, 1, v) - val, err := str.Int(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + var tVal1 interface{} = 1 + v, err = str.Int(tVal1, nil) + assert.NoError(t, err) + assert.Equal(t, 1, v) - testError := errors.New("test error") - val, err = str.Int(value, testError) - assert.Error(err, testError.Error()) + // Test if ErrNil is returned if value is nil. + v, err = str.Int(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, 0, v) + + // Test if custom error sent is returned. + v, err = str.Int(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, 0, v) + + // Test invalid assert error. + v, err = str.Int(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, 0, v) } func TestInt64(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + str := New(context.TODO(), nil) + + v, err := str.Int64(int64(1), nil) + assert.NoError(t, err) + assert.Equal(t, int64(1), v) + + v, err = str.Int64("1", nil) + assert.NoError(t, err) + assert.Equal(t, int64(1), v) + + v, err = str.Int64([]byte("1"), nil) + assert.NoError(t, err) + assert.Equal(t, int64(1), v) - field := "somekey" - var value int64 = 100 + var tVal interface{} = 1 + v, err = str.Int64(tVal, nil) + assert.NoError(t, err) + assert.Equal(t, int64(1), v) - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + // Test if ErrNil is returned if value is nil. + v, err = str.Int64(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, int64(0), v) - val, err := str.Int64(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + // Test if custom error sent is returned. + v, err = str.Int64(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, int64(0), v) - testError := errors.New("test error") - val, err = str.Int64(value, testError) - assert.Error(err, testError.Error()) + // Test invalid assert error. + v, err = str.Int64(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, int64(0), v) } func TestUInt64(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + str := New(context.TODO(), nil) - field := "somekey" - var value uint64 = 100 + v, err := str.UInt64(uint64(1), nil) + assert.NoError(t, err) + assert.Equal(t, uint64(1), v) - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + v, err = str.UInt64("1", nil) + assert.NoError(t, err) + assert.Equal(t, uint64(1), v) - val, err := str.UInt64(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + v, err = str.UInt64([]byte("1"), nil) + assert.NoError(t, err) + assert.Equal(t, uint64(1), v) - testError := errors.New("test error") - val, err = str.UInt64(value, testError) - assert.Error(err, testError.Error()) + var tVal interface{} = 1 + v, err = str.UInt64(tVal, nil) + assert.NoError(t, err) + assert.Equal(t, uint64(1), v) + + // Test if ErrNil is returned if value is nil. + v, err = str.UInt64(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, uint64(0), v) + + // Test if custom error sent is returned. + v, err = str.UInt64(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, uint64(0), v) + + // Test invalid assert error. + v, err = str.UInt64(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, uint64(0), v) } func TestFloat64(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + str := New(context.TODO(), nil) + + v, err := str.Float64(float64(1.11), nil) + assert.NoError(t, err) + assert.Equal(t, float64(1.11), v) - field := "somekey" - var value float64 = 100 + v, err = str.Float64("1.11", nil) + assert.NoError(t, err) + assert.Equal(t, float64(1.11), v) - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + v, err = str.Float64([]byte("1.11"), nil) + assert.NoError(t, err) + assert.Equal(t, float64(1.11), v) - val, err := str.Float64(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + var tVal float64 = 1.11 + v, err = str.Float64(tVal, nil) + assert.NoError(t, err) + assert.Equal(t, float64(1.11), v) - testError := errors.New("test error") - val, err = str.Float64(value, testError) - assert.Error(err, testError.Error()) + // Test if ErrNil is returned if value is nil. + v, err = str.Float64(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, float64(0), v) + + // Test if custom error sent is returned. + v, err = str.Float64(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, float64(0), v) + + // Test invalid assert error. + v, err = str.Float64("abc", nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, float64(0), v) } func TestString(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + str := New(context.TODO(), nil) + + v, err := str.String("abc", nil) + assert.NoError(t, err) + assert.Equal(t, "abc", v) - field := "somekey" - value := "abc123" + v, err = str.String([]byte("abc"), nil) + assert.NoError(t, err) + assert.Equal(t, "abc", v) - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + var tVal interface{} = "abc" + v, err = str.String(tVal, nil) + assert.NoError(t, err) + assert.Equal(t, "abc", v) - val, err := str.String(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + // Test if ErrNil is returned if value is nil. + v, err = str.String(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, "", v) - testError := errors.New("test error") - val, err = str.String(value, testError) - assert.Error(err, testError.Error()) + // Test if custom error sent is returned. + v, err = str.String(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, "", v) + + // Test invalid assert error. + v, err = str.String(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, "", v) } func TestBytes(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + str := New(context.TODO(), nil) + + v, err := str.Bytes("abc", nil) + assert.NoError(t, err) + assert.Equal(t, []byte("abc"), v) - field := "somekey" - var value []byte = []byte("abc123") + v, err = str.Bytes([]byte("abc"), nil) + assert.NoError(t, err) + assert.Equal(t, []byte("abc"), v) - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + var tVal interface{} = "abc" + v, err = str.Bytes(tVal, nil) + assert.NoError(t, err) + assert.Equal(t, []byte("abc"), v) - val, err := str.Bytes(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + // Test if ErrNil is returned if value is nil. + v, err = str.Bytes(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, []byte(nil), v) - testError := errors.New("test error") - val, err = str.Bytes(value, testError) - assert.Error(err, testError.Error()) + // Test if custom error sent is returned. + v, err = str.Bytes(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, []byte(nil), v) + + // Test invalid assert error. + v, err = str.Bytes(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, []byte(nil), v) } func TestBool(t *testing.T) { - assert := assert.New(t) - redisPool := getRedisPool() - str := New(redisPool) + str := New(context.TODO(), nil) + + v, err := str.Bool(true, nil) + assert.NoError(t, err) + assert.Equal(t, true, v) + + v, err = str.Bool(false, nil) + assert.NoError(t, err) + assert.Equal(t, false, v) + + v, err = str.Bool(0, nil) + assert.NoError(t, err) + assert.Equal(t, false, v) - field := "somekey" - value := true + v, err = str.Bool(1, nil) + assert.NoError(t, err) + assert.Equal(t, true, v) - conn := redisPool.Get() - defer conn.Close() - _, err := conn.Do("SET", field, value) - assert.NoError(err) + v, err = str.Bool(int64(0), nil) + assert.NoError(t, err) + assert.Equal(t, false, v) - val, err := str.Bool(conn.Do("GET", field)) - assert.NoError(err) - assert.Equal(value, val) + v, err = str.Bool(int64(1), nil) + assert.NoError(t, err) + assert.Equal(t, true, v) - testError := errors.New("test error") - val, err = str.Bool(value, testError) - assert.Error(err, testError.Error()) + v, err = str.Bool([]byte("true"), nil) + assert.NoError(t, err) + assert.Equal(t, true, v) + + v, err = str.Bool([]byte("false"), nil) + assert.NoError(t, err) + assert.Equal(t, false, v) + + v, err = str.Bool("true", nil) + assert.NoError(t, err) + assert.Equal(t, true, v) + + v, err = str.Bool("false", nil) + assert.NoError(t, err) + assert.Equal(t, false, v) + + // Test if ErrNil is returned if value is nil. + v, err = str.Bool(nil, nil) + assert.ErrorIs(t, err, ErrNil) + assert.Equal(t, false, v) + + // Test if custom error sent is returned. + v, err = str.Bool(nil, errTest) + assert.ErrorIs(t, err, errTest) + assert.Equal(t, false, v) + + // Test invalid assert error. + v, err = str.Bool(10.1112, nil) + assert.ErrorIs(t, err, ErrAssertType) + assert.Equal(t, false, v) +} + +func TestError(t *testing.T) { + err := Err{ + code: 1, + msg: "test", + } + assert.Equal(t, 1, err.Code()) + assert.Equal(t, "test", err.Error()) } diff --git a/stores/securecookie/go.mod b/stores/securecookie/go.mod index fbff51e..413de4f 100644 --- a/stores/securecookie/go.mod +++ b/stores/securecookie/go.mod @@ -1,8 +1,14 @@ -module github.com/vividvilla/simplesessions/stores/securecookie/v2 +module github.com/vividvilla/simplesessions/stores/securecookie/v3 -go 1.14 +go 1.18 require ( github.com/gorilla/securecookie v1.1.2 github.com/stretchr/testify v1.9.0 ) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/stores/securecookie/secure_cookie.go b/stores/securecookie/secure_cookie.go index 4016f94..ed50940 100644 --- a/stores/securecookie/secure_cookie.go +++ b/stores/securecookie/secure_cookie.go @@ -1,7 +1,7 @@ package securecookie import ( - "errors" + "fmt" "sync" "github.com/gorilla/securecookie" @@ -15,9 +15,8 @@ var ( // Error codes for store errors. This should match the codes // defined in the /simplesessions package exactly. ErrInvalidSession = &Err{code: 1, msg: "invalid session"} - ErrFieldNotFound = &Err{code: 2, msg: "field not found"} - ErrAssertType = &Err{code: 3, msg: "assertion failed"} - ErrNil = &Err{code: 4, msg: "nil returned"} + ErrAssertType = &Err{code: 2, msg: "assertion failed"} + ErrNil = &Err{code: 3, msg: "nil returned"} ) type Err struct { @@ -76,18 +75,20 @@ func (s *Store) SetCookieName(cookieName string) { } // IsValid checks if the given cookie value is valid. -func (s *Store) IsValid(cv string) (bool, error) { +func (s *Store) IsValid(cv string) bool { if _, err := s.decode(cv); err != nil { - return false, nil + return false } - - return true, nil + return true } // Create creates a new secure cookie session with empty map. -func (s *Store) Create() (string, error) { - // Create empty cookie - return s.encode(make(map[string]interface{})) +// Once called, Flush() should be called to retrieve the updated. +func (s *Store) Create(id string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.tempSetMap[id] = make(map[string]interface{}) + return nil } // Get returns a field value from session @@ -101,7 +102,7 @@ func (s *Store) Get(cv, key string) (interface{}, error) { // Get given field val, ok := vals[key] if !ok { - return nil, ErrFieldNotFound + return nil, nil } return val, nil @@ -117,9 +118,15 @@ func (s *Store) GetMulti(cv string, keys ...string) (map[string]interface{}, err } // Get all given fields - res := make(map[string]interface{}) + var ( + ok bool + res = make(map[string]interface{}) + ) for _, k := range keys { - res[k], _ = vals[k] + res[k], ok = vals[k] + if !ok { + res[k] = nil + } } return res, nil @@ -136,6 +143,8 @@ func (s *Store) GetAll(cv string) (map[string]interface{}, error) { } // Set sets a field in session but not saved untill commit is called. +// Flush() should be called to retrieve the updated, unflushed values +// and written to the cookie externally. func (s *Store) Set(cv, key string, val interface{}) error { s.mu.Lock() defer s.mu.Unlock() @@ -151,20 +160,37 @@ func (s *Store) Set(cv, key string, val interface{}) error { return nil } -// Commit is unsupported in this store. -func (s *Store) Commit(cv string) error { - return errors.New("Commit() is not supported. Use Flush() to get values and write to cookie externally.") +// SetMulti sets given map of kv pairs to session. Flush() should be +// called to retrieve the updated, unflushed values and written to the cookie +// externally. +func (s *Store) SetMulti(cv string, vals map[string]interface{}) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Create session map if doesn't exist + if _, ok := s.tempSetMap[cv]; !ok { + s.tempSetMap[cv] = make(map[string]interface{}) + } + + for k, v := range vals { + s.tempSetMap[cv][k] = v + } + + return nil } // Flush flushes the 'set' buffer and returns encoded secure cookie value ready to be saved. // This value should be written to the cookie externally. +// This can be used with simplessions.Session.WriteCookie. +// val, _ := str.Flush(cookieVal) +// sess.WriteCookie(val) func (s *Store) Flush(cv string) (string, error) { s.mu.Lock() defer s.mu.Unlock() vals, ok := s.tempSetMap[cv] if !ok { - return "", nil + return "", fmt.Errorf("nothing to flush") } delete(s.tempSetMap, cv) @@ -176,15 +202,17 @@ func (s *Store) Flush(cv string) (string, error) { // Delete deletes a field from session. Once called, Flush() should be // called to retrieve the updated, unflushed values and written to the cookie // externally. -func (s *Store) Delete(cv, key string) error { +func (s *Store) Delete(cv string, keys ...string) error { // Decode current cookie vals, err := s.decode(cv) if err != nil { return ErrInvalidSession } - // Delete given key in current values. - delete(vals, key) + for _, k := range keys { + // Delete given key in current values. + delete(vals, k) + } // Create session map if doesn't exist. s.mu.Lock() @@ -202,9 +230,21 @@ func (s *Store) Delete(cv, key string) error { return nil } -// Clear clears the session. +// Clear clears the session. Once called, Flush() should be +// called to retrieve the updated, unflushed values and written to the cookie +// externally. func (s *Store) Clear(cv string) error { - return errors.New("Clear() is not supported. Use Create() to create an empty map and write to cookie externally.") + s.mu.Lock() + defer s.mu.Unlock() + s.tempSetMap[cv] = make(map[string]interface{}) + return nil +} + +// Destroy clears the session. Once called, Flush() should be +// called to retrieve the updated, unflushed values and written to the cookie +// externally. +func (s *Store) Destroy(cv string) error { + return s.Clear(cv) } // Int is a helper method to type assert as integer diff --git a/stores/securecookie/secure_cookie_test.go b/stores/securecookie/secure_cookie_test.go index d61101a..b0ec709 100644 --- a/stores/securecookie/secure_cookie_test.go +++ b/stores/securecookie/secure_cookie_test.go @@ -13,379 +13,346 @@ var ( ) func TestNew(t *testing.T) { - assert := assert.New(t) str := New(secretKey, blockKey) - assert.NotNil(str.sc) - assert.NotNil(str.tempSetMap) + assert.NotNil(t, str.sc) + assert.NotNil(t, str.tempSetMap) } func TestSetCookieName(t *testing.T) { - assert := assert.New(t) str := New(secretKey, blockKey) - - assert.Equal(defaultCookieName, str.cookieName) + assert.Equal(t, defaultCookieName, str.cookieName) str.SetCookieName("csrftoken") - assert.Equal("csrftoken", str.cookieName) + assert.Equal(t, "csrftoken", str.cookieName) } func TestIsValid(t *testing.T) { - assert := assert.New(t) str := New(secretKey, blockKey) - - assert.False(str.IsValid("")) + assert.False(t, str.IsValid("")) encoded, err := str.encode(make(map[string]interface{})) - assert.Nil(err) - assert.True(str.IsValid(encoded)) + assert.Nil(t, err) + assert.True(t, str.IsValid(encoded)) } func TestCreate(t *testing.T) { - assert := assert.New(t) - str := New(secretKey, blockKey) - - id, err := str.Create() - assert.Nil(err) - assert.True(str.IsValid(id)) -} - -func TestGetInvalidSessionError(t *testing.T) { - assert := assert.New(t) str := New(secretKey, blockKey) - val, err := str.Get("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) + err := str.Create("testid") + assert.Nil(t, err) + assert.Contains(t, str.tempSetMap, "testid") + assert.Equal(t, 0, len(str.tempSetMap["testid"])) } func TestGet(t *testing.T) { - assert := assert.New(t) - field := "somekey" - value := 100 - - // Set a key str := New(secretKey, blockKey) - - m := make(map[string]interface{}) - m[field] = value + val, err := str.Get("invalidkey", "invalidkey") + assert.Nil(t, val) + assert.ErrorIs(t, err, ErrInvalidSession) + + var ( + field = "somekey" + value = 100 + m = map[string]interface{}{ + field: value, + } + ) cv, err := str.encode(m) - assert.Nil(err) - - val, err := str.Get(cv, field) - assert.NoError(err) - assert.Equal(val, value) -} - -func TestGetFieldNotFoundError(t *testing.T) { - assert := assert.New(t) - field := "someotherkey" + assert.Nil(t, err) - // Set a key - str := New(secretKey, blockKey) + val, err = str.Get(cv, field) + assert.NoError(t, err) + assert.Equal(t, val, value) - m := make(map[string]interface{}) - cv, err := str.encode(m) - assert.Nil(err) - - _, err = str.Get(cv, field) - assert.Error(ErrFieldNotFound) + val, err = str.Get(cv, "invalid") + assert.NoError(t, err) + assert.Equal(t, nil, val) } -func TestGetMultiInvalidSessionError(t *testing.T) { - assert := assert.New(t) +func TestGetMulti(t *testing.T) { str := New(secretKey, blockKey) - val, err := str.GetMulti("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) -} - -func TestGetMultiFieldEmptySession(t *testing.T) { - assert := assert.New(t) - str := New(secretKey, blockKey) + assert.Nil(t, val) + assert.ErrorIs(t, err, ErrInvalidSession) - m := make(map[string]interface{}) - cv, err := str.encode(m) - assert.Nil(err) - - _, err = str.GetMulti(cv) - assert.Nil(err) -} - -func TestGetMulti(t *testing.T) { - assert := assert.New(t) - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - field3 := "thishouldntbethere" - value3 := 100.10 - - str := New(secretKey, blockKey) + var ( + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + field3 = "thishouldntbethere" + ) // Set a key - m := make(map[string]interface{}) - m[field1] = value1 - m[field2] = value2 - m[field3] = value3 + m := map[string]interface{}{ + field1: value1, + field2: value2, + } cv, err := str.encode(m) - assert.Nil(err) - - vals, err := str.GetMulti(cv, field1, field2) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - assert.NotContains(vals, field3) - - assert.NoError(err) - assert.Equal(vals[field1], value1) - - assert.NoError(err) - assert.Equal(vals[field2], value2) + assert.Nil(t, err) + + vals, err := str.GetMulti(cv, field1, field2, field3) + assert.NoError(t, err) + assert.Contains(t, vals, field1) + assert.Contains(t, vals, field2) + assert.Contains(t, vals, field3) + assert.Equal(t, vals[field1], value1) + assert.Equal(t, vals[field2], value2) + assert.Equal(t, vals[field3], nil) } -func TestGetAllInvalidSessionError(t *testing.T) { - assert := assert.New(t) +func TestGetAll(t *testing.T) { str := New(secretKey, blockKey) val, err := str.GetAll("invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) -} + assert.Nil(t, val) + assert.ErrorIs(t, err, ErrInvalidSession) -func TestGetAll(t *testing.T) { - assert := assert.New(t) - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - str := New(secretKey, blockKey) + var ( + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + ) // Set a key - m := make(map[string]interface{}) - m[field1] = value1 - m[field2] = value2 + m := map[string]interface{}{ + field1: value1, + field2: value2, + } cv, err := str.encode(m) - assert.Nil(err) + assert.Nil(t, err) vals, err := str.GetAll(cv) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - - assert.NoError(err) - assert.Equal(vals[field1], value1) - - assert.NoError(err) - assert.Equal(vals[field2], value2) + assert.NoError(t, err) + assert.Contains(t, vals, field1) + assert.Contains(t, vals, field2) + assert.Equal(t, vals[field1], value1) + assert.Equal(t, vals[field2], value2) } func TestSet(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) - str := New(secretKey, blockKey) - - // this key is unique across all tests - field := "somekey" - value := 100 - - m := make(map[string]interface{}) + var ( + str = New(secretKey, blockKey) + + // this key is unique across all tests + field = "somekey" + value = 100 + ) + m := map[string]interface{}{ + field: value, + } cv, err := str.encode(m) - assert.Nil(err) + assert.Nil(t, err) err = str.Set(cv, field, value) - assert.NoError(err) - assert.Contains(str.tempSetMap, cv) - assert.Contains(str.tempSetMap[cv], field) - assert.Equal(str.tempSetMap[cv][field], value) + assert.NoError(t, err) + assert.Contains(t, str.tempSetMap, cv) + assert.Contains(t, str.tempSetMap[cv], field) + assert.Equal(t, str.tempSetMap[cv][field], value) } -func TestEmptyCommit(t *testing.T) { - assert := assert.New(t) - str := New(secretKey, blockKey) - - m := make(map[string]interface{}) +func TestSetMulti(t *testing.T) { + // Test should only set in internal map and not in redis + var ( + str = New(secretKey, blockKey) + + // this key is unique across all tests + field1 = "somekey1" + value1 = 100 + field2 = "somekey2" + value2 = 10 + ) + m := map[string]interface{}{ + field1: value1, + field2: value2, + } cv, err := str.encode(m) - assert.Nil(err) - - v, err := str.Flush(cv) - assert.Empty(v) - assert.NoError(err) + assert.Nil(t, err) + + err = str.SetMulti(cv, m) + assert.NoError(t, err) + assert.Contains(t, str.tempSetMap, cv) + assert.Contains(t, str.tempSetMap[cv], field1) + assert.Equal(t, str.tempSetMap[cv][field1], value1) + assert.Contains(t, str.tempSetMap[cv], field2) + assert.Equal(t, str.tempSetMap[cv][field2], value2) } -func TestCommit(t *testing.T) { - assert := assert.New(t) +func TestDelete(t *testing.T) { str := New(secretKey, blockKey) - // this key is unique across all tests - field := "somekey" - value := 100 + err := str.Delete("invalidkey", "somekey") + assert.ErrorIs(t, err, ErrInvalidSession) - m := make(map[string]interface{}) + m := map[string]interface{}{ + "key1": "val1", + "key2": "val2", + } cv, err := str.encode(m) - assert.Nil(err) - - err = str.Set(cv, field, value) - assert.NoError(err) - assert.Equal(len(str.tempSetMap), 1) - - v, err := str.Flush(cv) - assert.NotEmpty(v) - assert.NoError(err) - assert.Equal(len(str.tempSetMap), 0) - - decoded, err := str.decode(v) - assert.NoError(err) - assert.Contains(decoded, field) - assert.Equal(decoded[field], value) + assert.Nil(t, err) + assert.NoError(t, str.Delete(cv, "key1")) + assert.NotContains(t, str.tempSetMap[cv], "key1") + assert.Contains(t, str.tempSetMap[cv], "key2") } -func TestDeleteInvalidSessionError(t *testing.T) { - assert := assert.New(t) +func TestClear(t *testing.T) { str := New(secretKey, blockKey) - - err := str.Delete("invalidkey", "somekey") - assert.Error(err, ErrInvalidSession.Error()) + err := str.Clear("xxx") + assert.Nil(t, err) + assert.Equal(t, len(str.tempSetMap["xxx"]), 0) } -func TestDelete(t *testing.T) { - assert := assert.New(t) +func TestDestroy(t *testing.T) { str := New(secretKey, blockKey) + err := str.Destroy("xxx") + assert.Nil(t, err) + assert.Equal(t, len(str.tempSetMap["xxx"]), 0) +} - m := make(map[string]interface{}) - m["key1"] = "val1" - m["key2"] = "val2" - cv, err := str.encode(m) - assert.Nil(err) - - assert.NoError(str.Delete(cv, "key1")) - - v, err := str.Flush(cv) - assert.NoError(err) - - decoded, err := str.decode(v) - assert.NoError(err) - assert.NotContains(decoded, "key1") +func TestFlush(t *testing.T) { + str := New(secretKey, blockKey) + m := map[string]interface{}{ + "key1": "val1", + "key2": "val2", + } + + str.tempSetMap["id"] = m + cv, err := str.Flush("id") + assert.Nil(t, err) + + vals, err := str.decode(cv) + assert.Nil(t, err) + assert.NotContains(t, str.tempSetMap, cv) + assert.Contains(t, vals, "key1") + assert.Contains(t, vals, "key2") + assert.Equal(t, vals["key1"], "val1") + assert.Equal(t, vals["key2"], "val2") + + _, err = str.Flush("xxx") + assert.Equal(t, err.Error(), "nothing to flush") } func TestInt(t *testing.T) { - assert := assert.New(t) str := New(secretKey, blockKey) var want int = 10 v, err := str.Int(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") v, err = str.Int(want, testError) - assert.Equal(v, 0) - assert.Error(testError) + assert.Equal(t, v, 0) + assert.ErrorIs(t, err, testError) _, err = str.Int("string", nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, err, ErrAssertType) } func TestInt64(t *testing.T) { - assert := assert.New(t) str := New(secretKey, blockKey) var want int64 = 10 v, err := str.Int64(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") - v, err = str.Int64(want, testError) - assert.Error(testError) + _, err = str.Int64(want, testError) + assert.ErrorIs(t, err, testError) _, err = str.Int64("string", nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, err, ErrAssertType) } func TestUInt64(t *testing.T) { - assert := assert.New(t) str := New(secretKey, blockKey) var want uint64 = 10 v, err := str.UInt64(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") - v, err = str.UInt64(want, testError) - assert.Error(testError) + _, err = str.UInt64(want, testError) + assert.ErrorIs(t, err, testError) _, err = str.UInt64("string", nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, err, ErrAssertType) } func TestFloat64(t *testing.T) { - assert := assert.New(t) str := New(secretKey, blockKey) var want float64 = 10 v, err := str.Float64(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") - v, err = str.Float64(want, testError) - assert.Error(testError) + _, err = str.Float64(want, testError) + assert.ErrorIs(t, err, testError) _, err = str.Float64("string", nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, err, ErrAssertType) } func TestString(t *testing.T) { - assert := assert.New(t) str := New(secretKey, blockKey) var want = "string" v, err := str.String(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") - v, err = str.String(want, testError) - assert.Error(testError) + _, err = str.String(want, testError) + assert.ErrorIs(t, err, testError) _, err = str.String(123, nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, err, ErrAssertType) } func TestBytes(t *testing.T) { - assert := assert.New(t) str := New(secretKey, blockKey) var want = []byte("a") v, err := str.Bytes(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") - v, err = str.Bytes(want, testError) - assert.Error(testError) + _, err = str.Bytes(want, testError) + assert.ErrorIs(t, err, testError) _, err = str.Bytes("string", nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, err, ErrAssertType) } func TestBool(t *testing.T) { - assert := assert.New(t) str := New(secretKey, blockKey) var want = true v, err := str.Bool(want, nil) - assert.Nil(err) - assert.Equal(v, want) + assert.Nil(t, err) + assert.Equal(t, v, want) testError := errors.New("test error") - v, err = str.Bool(want, testError) - assert.Error(testError) + _, err = str.Bool(want, testError) + assert.ErrorIs(t, err, testError) _, err = str.Bool("string", nil) - assert.Error(ErrAssertType) + assert.ErrorIs(t, err, ErrAssertType) +} + +func TestError(t *testing.T) { + err := Err{ + code: 1, + msg: "test", + } + assert.Equal(t, 1, err.Code()) + assert.Equal(t, "test", err.Error()) }