Skip to content

Commit

Permalink
jwks tweaks, vendor deps (#415)
Browse files Browse the repository at this point in the history
  • Loading branch information
FZambia authored Jan 19, 2021
1 parent 7ba39ab commit 831ac26
Show file tree
Hide file tree
Showing 147 changed files with 24,976 additions and 936 deletions.
13 changes: 4 additions & 9 deletions internal/jwks/cache.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
package jwks

import (
"context"
"errors"
)

var (
// ErrEmptyKeyID raises when input kid is empty.
ErrEmptyKeyID = errors.New("cache: empty kid")
// ErrCacheNotFound raises when cache value not found.
// ErrCacheNotFound returned when cache value not found.
ErrCacheNotFound = errors.New("cache: value not found")
// ErrInvalidValue raises when type conversion to JWK has been failed.
ErrInvalidValue = errors.New("cache: invalid value")
)

// Cache works with cache layer.
type Cache interface {
Add(ctx context.Context, key *JWK) error
Get(ctx context.Context, kid string) (*JWK, error)
Len(ctx context.Context) (int, error)
Add(key *JWK) error
Get(kid string) (*JWK, error)
Len() (int, error)
}
49 changes: 17 additions & 32 deletions internal/jwks/cache_ttl.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package jwks

import (
"context"
"sync"
"time"
)
Expand Down Expand Up @@ -31,10 +30,11 @@ func (i *item) expired() bool {

// TTLCache is a TTL bases in-memory cache.
type TTLCache struct {
mu sync.RWMutex
ttl time.Duration
stop chan struct{}
items map[string]*item
mu sync.RWMutex
ttl time.Duration
stop chan struct{}
stopOnce sync.Once
items map[string]*item
}

// NewTTLCache returns a new instance of ttl cache.
Expand Down Expand Up @@ -78,7 +78,7 @@ func (tc *TTLCache) run() {
}

// Add item into cache.
func (tc *TTLCache) Add(_ context.Context, key *JWK) error {
func (tc *TTLCache) Add(key *JWK) error {
tc.mu.Lock()
item := &item{data: key}
item.touch(tc.ttl)
Expand All @@ -88,7 +88,7 @@ func (tc *TTLCache) Add(_ context.Context, key *JWK) error {
}

// Get item by key.
func (tc *TTLCache) Get(_ context.Context, kid string) (*JWK, error) {
func (tc *TTLCache) Get(kid string) (*JWK, error) {
tc.mu.RLock()
item, ok := tc.items[kid]
if !ok || item.expired() {
Expand All @@ -100,40 +100,25 @@ func (tc *TTLCache) Get(_ context.Context, kid string) (*JWK, error) {
return item.data, nil
}

// Remove item by key.
func (tc *TTLCache) Remove(_ context.Context, kid string) error {
// Stop stops TTL cache.
func (tc *TTLCache) Stop() error {
tc.stopOnce.Do(func() {
close(tc.stop)
})
return nil
}

func (tc *TTLCache) remove(kid string) error {
tc.mu.Lock()
delete(tc.items, kid)
tc.mu.Unlock()
return nil
}

// Contains checks item on existence.
func (tc *TTLCache) Contains(_ context.Context, kid string) (bool, error) {
tc.mu.RLock()
_, ok := tc.items[kid]
tc.mu.RUnlock()
return ok, nil
}

// Len returns current size of cache.
func (tc *TTLCache) Len(_ context.Context) (int, error) {
func (tc *TTLCache) Len() (int, error) {
tc.mu.RLock()
n := len(tc.items)
tc.mu.RUnlock()
return n, nil
}

// Purge deletes all items.
func (tc *TTLCache) Purge(_ context.Context) error {
tc.mu.Lock()
tc.items = map[string]*item{}
tc.mu.Unlock()
return nil
}

// Stop cleanup process.
func (tc *TTLCache) Stop(_ context.Context) error {
tc.stop <- struct{}{}
return nil
}
171 changes: 22 additions & 149 deletions internal/jwks/cache_ttl_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package jwks

import (
"context"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -29,13 +28,11 @@ func TestTTLCacheAdd(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(tc.TTL)
require.NotNil(t, cache)

for i := 0; i < tc.Ops; i++ {
require.NoError(t, cache.Add(ctx, &JWK{
require.NoError(t, cache.Add(&JWK{
Kid: fmt.Sprintf("key-%d", i+1),
Kty: "RSA",
Alg: "RS256",
Expand Down Expand Up @@ -78,13 +75,11 @@ func TestTTLCacheGet(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(5 * time.Minute)
require.NotNil(t, cache)
require.NoError(t, cache.Add(ctx, tc.Key))
require.NoError(t, cache.Add(tc.Key))

key, err := cache.Get(ctx, tc.Kid)
key, err := cache.Get(tc.Kid)
if tc.Error != nil {
require.Error(t, err)
require.ErrorIs(t, err, tc.Error)
Expand All @@ -98,178 +93,56 @@ func TestTTLCacheGet(t *testing.T) {

func TestTTLCacheRemove(t *testing.T) {
testCases := []struct {
Name string
Adds int
Dels int
Len int
Name string
NumAdd int
NumDelete int
Len int
}{
{
Name: "OK",
Adds: 75,
Dels: 50,
Len: 25,
Name: "OK",
NumAdd: 75,
NumDelete: 50,
Len: 25,
},
{
Name: "RemoveUntilEmpty",
Adds: 75,
Dels: 100,
Len: 0,
Name: "RemoveUntilEmpty",
NumAdd: 75,
NumDelete: 100,
Len: 0,
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(5 * time.Minute)
require.NotNil(t, cache)

for i := 0; i < tc.Adds; i++ {
require.NoError(t, cache.Add(ctx, &JWK{
for i := 0; i < tc.NumAdd; i++ {
require.NoError(t, cache.Add(&JWK{
Kid: fmt.Sprintf("key-%d", i+1),
Kty: "RSA",
Alg: "RS256",
Use: "sig",
}))
}

for i := 0; i < tc.Dels; i++ {
for i := 0; i < tc.NumDelete; i++ {
kid := fmt.Sprintf("key-%d", i+1)
require.NoError(t, cache.Remove(ctx, kid))
require.NoError(t, cache.remove(kid))
}

n, err := cache.Len(ctx)
n, err := cache.Len()
require.NoError(t, err)
require.Equal(t, tc.Len, n)
})
}
}

func TestTTLCacheContains(t *testing.T) {
testCases := []struct {
Name string
Key *JWK
Kid string
Found bool
}{
{
Name: "OK",
Key: &JWK{
Kid: "202101",
Kty: "RSA",
Alg: "RS256",
Use: "sig",
},
Kid: "202101",
Found: true,
},
{
Name: "NotFound",
Key: &JWK{
Kid: "202101",
Kty: "RSA",
Alg: "RS256",
Use: "sig",
},
Kid: "202102",
Found: false,
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(5 * time.Minute)
require.NotNil(t, cache)
require.NoError(t, cache.Add(ctx, tc.Key))

found, err := cache.Contains(ctx, tc.Kid)
require.NoError(t, err)

require.Equal(t, tc.Found, found)
})
}
}

func TestTTLCacheLen(t *testing.T) {
testCases := []struct {
Name string
Ops int
Len int
}{
{
Name: "OK",
Ops: 50,
Len: 50,
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(5 * time.Second)
require.NotNil(t, cache)

for i := 0; i < tc.Ops; i++ {
require.NoError(t, cache.Add(ctx, &JWK{
Kid: fmt.Sprintf("key-%d", i+1),
Kty: "RSA",
Alg: "RS256",
Use: "sig",
}))
}

n, err := cache.Len(ctx)
require.NoError(t, err)
require.Equal(t, tc.Len, n)
})
}
}

func TestTTLCachePurge(t *testing.T) {
testCases := []struct {
Name string
Ops int
}{
{
Name: "OK",
Ops: 50,
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(5 * time.Second)
require.NotNil(t, cache)

for i := 0; i < tc.Ops; i++ {
require.NoError(t, cache.Add(ctx, &JWK{
Kid: fmt.Sprintf("key-%d", i+1),
Kty: "RSA",
Alg: "RS256",
Use: "sig",
}))
}

require.NoError(t, cache.Purge(ctx))

n, err := cache.Len(ctx)
require.NoError(t, err)
require.Equal(t, 0, n)
})
}
}

func TestTTLCacheCleanup(t *testing.T) {
ctx := context.Background()
cache := NewTTLCache(1 * time.Millisecond)

for i := 0; i < 10; i++ {
require.NoError(t, cache.Add(ctx, &JWK{
require.NoError(t, cache.Add(&JWK{
Kid: fmt.Sprintf("key-%d", i+1),
Kty: "RSA",
Alg: "RS256",
Expand All @@ -279,7 +152,7 @@ func TestTTLCacheCleanup(t *testing.T) {

time.Sleep(2 * time.Second)

n, err := cache.Len(ctx)
n, err := cache.Len()
require.NoError(t, err)
require.Equal(t, 0, n)
}
Loading

0 comments on commit 831ac26

Please sign in to comment.