Skip to content

Commit

Permalink
refactor: postgres package to v3 spec
Browse files Browse the repository at this point in the history
  • Loading branch information
vividvilla committed May 30, 2024
1 parent 194668d commit f8611d6
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 176 deletions.
168 changes: 26 additions & 142 deletions stores/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@ CREATE INDEX idx_sessions ON sessions (id, created_at);
*/

import (
"crypto/rand"
"database/sql"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"unicode"

_ "github.com/lib/pq"
)
Expand All @@ -26,9 +22,8 @@ var (
// Error codes for store errors. This should match the codes
// defined in the /simplesessions package exactly.
ErrInvalidSession = &Err{code: 1, msg: "invalid session"}
ErrFieldNotFound = &Err{code: 2, msg: "field not found"}
ErrNil = &Err{code: 2, msg: "nil returned"}
ErrAssertType = &Err{code: 3, msg: "assertion failed"}
ErrNil = &Err{code: 4, msg: "nil returned"}
)

type Err struct {
Expand Down Expand Up @@ -59,11 +54,6 @@ type Store struct {
db *sql.DB
opt Opt
q *queries

commitID string
tx *sql.Tx
stmt *sql.Stmt
sync.Mutex
}

type Opt struct {
Expand All @@ -75,10 +65,6 @@ type Opt struct {
CleanInterval time.Duration `json:"clean_interval"`
}

const (
sessionIDLen = 32
)

// New creates a new Postgres store instance.
func New(opt Opt, db *sql.DB) (*Store, error) {
if opt.Table == "" {
Expand Down Expand Up @@ -107,53 +93,31 @@ func New(opt Opt, db *sql.DB) (*Store, error) {
}

// Create creates a new session and returns the ID.
func (s *Store) Create() (string, error) {
id, err := generateID(sessionIDLen)
if err != nil {
return "", err
}

if _, err := s.q.create.Exec(id); err != nil {
return "", err
}
return id, nil
func (s *Store) Create(id string) error {
_, err := s.q.create.Exec(id)
return err
}

// Get returns a single session field's value.
func (s *Store) Get(id, key string) (interface{}, error) {
if !validateID(id) {
return nil, ErrInvalidSession
}

// Scan the whole JSON map out so that it can be unmarshalled,
// preserving the types.
var b []byte
if err := s.q.get.QueryRow(id, s.opt.TTL.Seconds()).Scan(&b); err != nil {
vals, err := s.GetAll(id)
if err != nil {
if err == sql.ErrNoRows {
return nil, ErrInvalidSession
}
return nil, err
}

var mp map[string]interface{}
if err := json.Unmarshal(b, &mp); err != nil {
return nil, err
}

v, ok := mp[key]
v, ok := vals[key]
if !ok {
return nil, ErrFieldNotFound
return nil, nil
}

return v, nil
}

// GetMulti gets a map for values for multiple keys. If a key doesn't exist, it returns ErrFieldNotFound.
func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, error) {
if !validateID(id) {
return nil, ErrInvalidSession
}

vals, err := s.GetAll(id)
if err != nil {
return nil, err
Expand All @@ -163,7 +127,7 @@ func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, err
for _, k := range keys {
v, ok := vals[k]
if !ok {
return nil, ErrFieldNotFound
return nil, nil
}
out[k] = v
}
Expand All @@ -173,10 +137,6 @@ func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, err

// GetAll returns the map of all keys in the session.
func (s *Store) GetAll(id string) (map[string]interface{}, error) {
if !validateID(id) {
return nil, ErrInvalidSession
}

var b []byte
err := s.q.get.QueryRow(id, s.opt.TTL.Seconds()).Scan(&b)
if err != nil {
Expand All @@ -193,45 +153,13 @@ func (s *Store) GetAll(id string) (map[string]interface{}, error) {

// Set sets a value to given session but is stored only on commit.
func (s *Store) Set(id, key string, val interface{}) (err error) {
if !validateID(id) {
return ErrInvalidSession
}

b, err := json.Marshal(map[string]interface{}{key: val})
if err != nil {
return err
}

s.Lock()
defer func() {
if err == nil {
s.Unlock()
return
}

if s.tx != nil {
s.tx.Rollback()
s.tx = nil
}
s.stmt = nil

s.Unlock()
}()

// If a transaction isn't set, set it.
if s.tx == nil {
tx, err := s.db.Begin()
if err != nil {
return err
}

// Prepare the statement for executing SQL commands
s.tx = tx
s.stmt = tx.Stmt(s.q.update)
}

// Execute the query in the batch to be committed later.
res, err := s.stmt.Exec(id, json.RawMessage(b))
res, err := s.q.update.Exec(id, json.RawMessage(b))
if err != nil {
return err
}
Expand All @@ -245,47 +173,36 @@ func (s *Store) Set(id, key string, val interface{}) (err error) {
return ErrInvalidSession
}

s.commitID = id
return err
}

// Commit sets all set values
func (s *Store) Commit(id string) error {
if !validateID(id) {
return ErrInvalidSession
// Set sets a value to given session but is stored only on commit.
func (s *Store) SetMulti(id string, data map[string]interface{}) (err error) {
b, err := json.Marshal(data)
if err != nil {
return err
}

s.Lock()
if s.commitID != id {
s.Unlock()
return ErrInvalidSession
// Execute the query in the batch to be committed later.
res, err := s.q.update.Exec(id, json.RawMessage(b))
if err != nil {
return err
}

defer func() {
if s.stmt != nil {
s.stmt.Close()
}
s.tx = nil
s.stmt = nil
s.Unlock()
}()

if s.tx == nil {
return errors.New("nothing to commit")
num, err := res.RowsAffected()
if err != nil {
return err
}
if s.commitID != id {

// No row was updated. The session didn't exist.
if num == 0 {
return ErrInvalidSession
}

return s.tx.Commit()
return err
}

// Delete deletes a key from redis session hashmap.
func (s *Store) Delete(id string, key string) error {
if !validateID(id) {
return ErrInvalidSession
}

res, err := s.q.delete.Exec(id, key)
if err != nil {
return err
Expand All @@ -306,10 +223,6 @@ func (s *Store) Delete(id string, key string) error {

// Clear clears session in redis.
func (s *Store) Clear(id string) error {
if !validateID(id) {
return ErrInvalidSession
}

res, err := s.q.clear.Exec(id)
if err != nil {
return err
Expand Down Expand Up @@ -471,32 +384,3 @@ func (s *Store) prepareQueries() (*queries, error) {

return q, err
}

func validateID(id string) bool {
if len(id) != sessionIDLen {
return false
}

for _, r := range id {
if !unicode.IsDigit(r) && !unicode.IsLetter(r) {
return false
}
}

return true
}

// generateID generates a random alpha-num session ID.
func generateID(n int) (string, error) {
const dict = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
bytes := make([]byte, n)
if _, err := rand.Read(bytes); err != nil {
return "", err
}

for k, v := range bytes {
bytes[k] = dict[v%byte(len(dict))]
}

return string(bytes), nil
}
Loading

0 comments on commit f8611d6

Please sign in to comment.