Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Opaque Tokens #242

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,8 @@ issues:
text: "hugeParam: user is heavy"
linters:
- gocritic
- path: internal/oidctesting/
linters:
- funlen
- gocritic
- unparam
1 change: 1 addition & 0 deletions examples/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ func run(cfg shared.RuntimeConfig) error {
opts = []options.Option{
options.WithIssuer(cfg.Issuer),
options.WithFallbackSignatureAlgorithm(cfg.FallbackSignatureAlgorithm),
options.WithOpaqueTokensEnabled(),
}
claimsValidationFn := shared.GetOPTestClaimsValidationFn(cfg.RequiredOPTestClientId)
return getHandler(cfg, claimsValidationFn, opts...)
Expand Down
2 changes: 2 additions & 0 deletions examples/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,5 @@ require (
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

replace github.com/xenitab/go-oidc-middleware/oidcgin => ../oidcgin
2 changes: 2 additions & 0 deletions examples/op/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
func main() {
op, err := optest.New(
optest.WithIssuer("http://localhost:8082"),
optest.WithOpaqueAccessTokens(),
optest.WithoutAutoStart(),
optest.WithDefaultTestUser("test"),
optest.WithLoginPrompt(),
Expand Down Expand Up @@ -65,6 +66,7 @@ func main() {
r.Any("/token", gin.WrapH(opRouter))
r.Any("/jwks", gin.WrapH(opRouter))
r.Any("/login", gin.WrapH(opRouter))
r.Any("/userinfo", gin.WrapH(opRouter))

r.GET("/get_test_token", func(c *gin.Context) {
token, err := op.GetToken()
Expand Down
1 change: 1 addition & 0 deletions examples/pkce-cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func run(cfg config) error {
})

g.Go(func() error {
fmt.Println(startUrl)
return openUrlWithBrowser(startUrl)
})

Expand Down
13 changes: 5 additions & 8 deletions examples/shared/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,11 @@ func GetOktaClaimsValidationFn(requiredClientId string) options.ClaimsValidation
}

type OPTestClaims struct {
Audience []string `json:"aud"`
ExpiresAt time.Time `json:"exp"`
IssuedAt time.Time `json:"iat"`
Id string `json:"id"`
Issuer string `json:"iss"`
NotBefore time.Time `json:"nbf"`
Subject string `json:"sub"`
ClientId string `json:"client_id"`
Audience string `json:"aud"`
Id string `json:"id"`
Issuer string `json:"iss"`
Subject string `json:"sub"`
ClientId string `json:"client_id"`
}

func GetOPTestClaimsValidationFn(requiredClientId string) options.ClaimsValidationFn[OPTestClaims] {
Expand Down
21 changes: 0 additions & 21 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ go 1.20
require (
github.com/lestrrat-go/jwx v1.2.25
github.com/stretchr/testify v1.8.1
github.com/xenitab/dispans v0.0.10
go.uber.org/ratelimit v0.2.0
golang.org/x/sync v0.1.0
)
Expand All @@ -14,14 +13,7 @@ require (
github.com/andres-erbsen/clock v0.0.0-20160526145045-9e14626cd129 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0 // indirect
github.com/go-oauth2/oauth2/v4 v4.4.2 // indirect
github.com/go-session/session v3.1.2+incompatible // indirect
github.com/goccy/go-json v0.10.0 // indirect
github.com/golang-jwt/jwt v3.2.2+incompatible // indirect
github.com/golang/protobuf v1.5.2 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/mux v1.8.0 // indirect
github.com/klauspost/compress v1.15.1 // indirect
github.com/kr/pretty v0.3.0 // indirect
github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect
github.com/lestrrat-go/blackmagic v1.0.1 // indirect
Expand All @@ -31,20 +23,7 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/tidwall/btree v0.6.1 // indirect
github.com/tidwall/buntdb v1.2.7 // indirect
github.com/tidwall/gjson v1.11.0 // indirect
github.com/tidwall/grect v0.1.3 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/rtred v0.1.2 // indirect
github.com/tidwall/tinyqueue v0.1.1 // indirect
github.com/valyala/fasthttp v1.35.0 // indirect
golang.org/x/crypto v0.6.0 // indirect
golang.org/x/net v0.6.0 // indirect
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.28.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
493 changes: 0 additions & 493 deletions go.sum

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions internal/oidc/keyhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ func (h *keyHandler) waitForUpdateKeySetAndGetKeySet(ctx context.Context) (jwk.S
return result.keySet, result.err
}

//nolint:unparam // for some reason golangci-lint has started complaining about the jwk.Key with it being unused, which it is
func (h *keyHandler) waitForUpdateKeySetAndGetKey(ctx context.Context) (jwk.Key, error) {
keySet, err := h.waitForUpdateKeySetAndGetKeySet(ctx)
if err != nil {
Expand All @@ -119,6 +120,7 @@ func (h *keyHandler) waitForUpdateKeySetAndGetKey(ctx context.Context) (jwk.Key,

return key, nil
}

func (h *keyHandler) getKey(ctx context.Context, tokenKeyID string, tokenAlgorithm jwa.SignatureAlgorithm) (jwk.Key, error) {
if h.disableKeyID {
return h.getKeyWithoutKeyID(tokenAlgorithm)
Expand Down
8 changes: 5 additions & 3 deletions internal/oidc/keyhandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ func TestNewKeyHandler(t *testing.T) {
op := optest.NewTesting(t)
issuer := op.GetURL(t)
discoveryUri := GetDiscoveryUriFromIssuer(issuer)
jwksUri, err := getJwksUriFromDiscoveryUri(http.DefaultClient, discoveryUri, 10*time.Millisecond)
metadata, err := GetOidcMetadataFromDiscoveryUri(http.DefaultClient, discoveryUri, 10*time.Millisecond)
jwksUri := metadata.JwksUri
require.NoError(t, err)

keyHandler, err := newKeyHandler(http.DefaultClient, jwksUri, 10*time.Millisecond, 100, false)
Expand Down Expand Up @@ -106,7 +107,8 @@ func TestUpdate(t *testing.T) {
op := optest.NewTesting(t)
issuer := op.GetURL(t)
discoveryUri := GetDiscoveryUriFromIssuer(issuer)
jwksUri, err := getJwksUriFromDiscoveryUri(http.DefaultClient, discoveryUri, 10*time.Millisecond)
metadata, err := GetOidcMetadataFromDiscoveryUri(http.DefaultClient, discoveryUri, 10*time.Millisecond)
jwksUri := metadata.JwksUri
require.NoError(t, err)

rateLimit := uint(10)
Expand Down Expand Up @@ -173,7 +175,7 @@ func TestUpdate(t *testing.T) {
stop := time.Now()
expectedStop := start.Add(time.Second / time.Duration(rateLimit))

require.WithinDuration(t, expectedStop, stop, 20*time.Millisecond)
require.WithinDuration(t, expectedStop, stop, 30*time.Millisecond)

require.Equal(t, 7, keyHandler.keyUpdateCount)
}
Expand Down
71 changes: 43 additions & 28 deletions internal/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ var (
errSignatureVerification = fmt.Errorf("failed to verify signature")
)

type handler[T any] struct {
type jwtHandler[T any] struct {
issuer string
discoveryUri string
discoveryFetchTimeout time.Duration
Expand All @@ -40,10 +40,24 @@ type handler[T any] struct {
claimsValidationFn options.ClaimsValidationFn[T]
}

func NewHandler[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...options.Option) (*handler[T], error) {
type Handler[T any] interface {
ParseToken(ctx context.Context, tokenString string) (T, error)
SetIssuer(issuer string)
SetDiscoveryUri(discoveryUri string)
}

func NewHandler[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...options.Option) (Handler[T], error) {
opts := options.New(setters...)
if opts.OpaqueTokensEnabled {
return newOpaqueHandler(claimsValidationFn, setters...)
}

return newjwtHandler(claimsValidationFn, setters...)
}

h := &handler[T]{
func newjwtHandler[T any](claimsValidationFn options.ClaimsValidationFn[T], setters ...options.Option) (*jwtHandler[T], error) {
opts := options.New(setters...)
h := &jwtHandler[T]{
issuer: opts.Issuer,
discoveryUri: opts.DiscoveryUri,
discoveryFetchTimeout: opts.DiscoveryFetchTimeout,
Expand Down Expand Up @@ -73,7 +87,7 @@ func NewHandler[T any](claimsValidationFn options.ClaimsValidationFn[T], setters

h.fallbackSignatureAlgorithm = alg
}
if !opts.LazyLoadJwks {
if !opts.LazyLoadMetadata {
err := h.loadJwks()
if err != nil {
return nil, fmt.Errorf("unable to load jwks: %w", err)
Expand All @@ -83,13 +97,16 @@ func NewHandler[T any](claimsValidationFn options.ClaimsValidationFn[T], setters
return h, nil
}

func (h *handler[T]) loadJwks() error {
func (h *jwtHandler[T]) loadJwks() error {
if h.jwksUri == "" {
jwksUri, err := getJwksUriFromDiscoveryUri(h.httpClient, h.discoveryUri, h.discoveryFetchTimeout)
metadata, err := GetOidcMetadataFromDiscoveryUri(h.httpClient, h.discoveryUri, h.discoveryFetchTimeout)
if err != nil {
return fmt.Errorf("unable to fetch jwksUri from discoveryUri (%s): %w", h.discoveryUri, err)
}
h.jwksUri = jwksUri
if metadata.JwksUri == "" {
return fmt.Errorf("JwksUri is empty")
}
h.jwksUri = metadata.JwksUri
}

keyHandler, err := newKeyHandler(h.httpClient, h.jwksUri, h.jwksFetchTimeout, h.jwksRateLimit, h.disableKeyID)
Expand All @@ -102,17 +119,17 @@ func (h *handler[T]) loadJwks() error {
return nil
}

func (h *handler[T]) SetIssuer(issuer string) {
func (h *jwtHandler[T]) SetIssuer(issuer string) {
h.issuer = issuer
}

func (h *handler[T]) SetDiscoveryUri(discoveryUri string) {
func (h *jwtHandler[T]) SetDiscoveryUri(discoveryUri string) {
h.discoveryUri = discoveryUri
}

type ParseTokenFunc[T any] func(ctx context.Context, tokenString string) (T, error)

func (h *handler[T]) ParseToken(ctx context.Context, tokenString string) (T, error) {
func (h *jwtHandler[T]) ParseToken(ctx context.Context, tokenString string) (T, error) {
if h.keyHandler == nil {
err := h.loadJwks()
if err != nil {
Expand Down Expand Up @@ -204,15 +221,15 @@ func (h *handler[T]) ParseToken(ctx context.Context, tokenString string) (T, err
return claims, nil
}

func (h *handler[T]) validateClaims(claims *T) error {
func (h *jwtHandler[T]) validateClaims(claims *T) error {
if h.claimsValidationFn == nil {
return nil
}

return h.claimsValidationFn(claims)
}

func (h *handler[T]) jwtTokenToClaims(ctx context.Context, token jwt.Token) (T, error) {
func (h *jwtHandler[T]) jwtTokenToClaims(ctx context.Context, token jwt.Token) (T, error) {
rawClaims, err := token.AsMap(ctx)
if err != nil {
return *new(T), fmt.Errorf("unable to convert token to claims: %w", err)
Expand All @@ -236,46 +253,44 @@ func GetDiscoveryUriFromIssuer(issuer string) string {
return fmt.Sprintf("%s/.well-known/openid-configuration", strings.TrimSuffix(issuer, "/"))
}

func getJwksUriFromDiscoveryUri(httpClient *http.Client, discoveryUri string, fetchTimeout time.Duration) (string, error) {
type oidcMetadata struct {
JwksUri string `json:"jwks_uri"`
UserinfoEndpoint string `json:"userinfo_endpoint"`
}

func GetOidcMetadataFromDiscoveryUri(httpClient *http.Client, discoveryUri string, fetchTimeout time.Duration) (oidcMetadata, error) {
ctx, cancel := context.WithTimeout(context.Background(), fetchTimeout)
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryUri, nil)
if err != nil {
return "", err
return oidcMetadata{}, err
}

req.Header.Set("Accept", "application/json")

res, err := httpClient.Do(req)
if err != nil {
return "", err
return oidcMetadata{}, err
}

bodyBytes, err := io.ReadAll(res.Body)
if err != nil {
return "", err
return oidcMetadata{}, err
}

err = res.Body.Close()
if err != nil {
return "", err
return oidcMetadata{}, err
}

var discoveryData struct {
JwksUri string `json:"jwks_uri"`
}

err = json.Unmarshal(bodyBytes, &discoveryData)
metadata := oidcMetadata{}
err = json.Unmarshal(bodyBytes, &metadata)
if err != nil {
return "", err
}

if discoveryData.JwksUri == "" {
return "", fmt.Errorf("JwksUri is empty")
return oidcMetadata{}, err
}

return discoveryData.JwksUri, nil
return metadata, nil
}

func getKeyIDFromTokenHeader(tokenHeaders jws.Headers) (string, error) {
Expand Down
11 changes: 6 additions & 5 deletions internal/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ import (
"testing"
"time"

"github.com/xenitab/go-oidc-middleware/optest"
"github.com/xenitab/go-oidc-middleware/options"

"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jws"
"github.com/lestrrat-go/jwx/jwt"
"github.com/stretchr/testify/require"
"github.com/xenitab/dispans/server"
)

type testClaims map[string]interface{}
Expand Down Expand Up @@ -457,12 +457,13 @@ func TestIsTokenTypeValid(t *testing.T) {
}

func TestGetAndValidateTokenFromString(t *testing.T) {
op := server.NewTesting(t)
op := optest.NewTesting(t)
defer op.Close(t)

issuer := op.GetURL(t)
discoveryUri := GetDiscoveryUriFromIssuer(issuer)
jwksUri, err := getJwksUriFromDiscoveryUri(http.DefaultClient, discoveryUri, 10*time.Millisecond)
metadata, err := GetOidcMetadataFromDiscoveryUri(http.DefaultClient, discoveryUri, 10*time.Millisecond)
jwksUri := metadata.JwksUri
require.NoError(t, err)

keyHandler, err := newKeyHandler(http.DefaultClient, jwksUri, 50*time.Millisecond, 100, false)
Expand All @@ -474,7 +475,7 @@ func TestGetAndValidateTokenFromString(t *testing.T) {
validAccessToken := op.GetToken(t).AccessToken
require.NotEmpty(t, validAccessToken)

validIDToken, ok := op.GetToken(t).Extra("id_token").(string)
validIDToken := op.GetToken(t).IdToken
require.True(t, ok)
require.NotEmpty(t, validIDToken)

Expand Down Expand Up @@ -612,7 +613,7 @@ func TestParseToken(t *testing.T) {
options.WithJwksUri(testServer.URL),
options.WithDisableKeyID(true),
options.WithJwksRateLimit(100),
options.WithLazyLoadJwks(true),
options.WithLazyLoadMetadata(true),
},
numKeys: 2,
expectedErrorContains: "keyID is disabled, but received a keySet with more than one key",
Expand Down
Loading