Skip to content

Commit

Permalink
add caching of STS-exchanged access tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
DrDaveD committed Jul 22, 2024
1 parent ed4005c commit b3f6ff4
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 64 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,12 @@ the corresponding `creds/:name` path.
#### `GET` (`read`)

Retrieve a new access token by performing a token exchange request on demand.
The token exchange operation always sends the access token from the
The token exchange operation sends the access token from the
corresponding credential as the subject token and explicitly requests a new
access token from the authorization server.
Reuses previous token that was made with the same parameters
if the provider specified an expiration time
and the token is not yet expired or close to it.

Parameters:

Expand All @@ -502,6 +505,7 @@ Parameters:
| `scopes` | A list of explicit scopes to request. | List of String | None | No |
| `audiences` | A list of explicit audiences to request. | List of String | None | No |
| `resources` | A list of explicit resources to request. | List of String | None | No |
| `minimum_seconds` | Minimum additional duration to require the access token to be valid for. | Integer | 10<sup id="ret-3-b">[3](#footnote-3)</sup> | No |

## Providers

Expand Down
2 changes: 1 addition & 1 deletion pkg/backend/path_creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (b *backend) credsReadOperation(ctx context.Context, req *logical.Request,
}

return logical.ErrorResponse("token pending issuance"), nil
case !b.tokenValid(entry.Token, expiryDelta):
case !b.tokenValid(entry.Token.Token, expiryDelta):
if entry.AuthServerError != "" {
return logical.ErrorResponse("server %q has configuration problems: %s", entry.AuthServerName, entry.AuthServerError), nil
} else if entry.UserError != "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/backend/path_self.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (b *backend) selfReadOperation(ctx context.Context, req *logical.Request, d
return nil, err
case entry == nil:
return nil, nil
case !b.tokenValid(entry.Token, expiryDelta):
case !b.tokenValid(entry.Token.Token, expiryDelta):
return logical.ErrorResponse("token expired"), nil
}

Expand Down
85 changes: 64 additions & 21 deletions pkg/backend/path_sts.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@ import (
"context"
"fmt"
"strings"
"time"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/puppetlabs/leg/errmap/pkg/errmap"
"github.com/puppetlabs/leg/errmap/pkg/errmark"
"github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/persistence"
"github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/provider"
"golang.org/x/oauth2"
)

func (b *backend) stsReadOperation(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) {
keyer := persistence.AuthCodeName(data.Get("name").(string))
entry, err := b.getRefreshCredToken(
ctx,
req.Storage,
persistence.AuthCodeName(data.Get("name").(string)),
keyer,
defaultExpiryDelta,
)
switch {
Expand All @@ -33,7 +36,7 @@ func (b *backend) stsReadOperation(ctx context.Context, req *logical.Request, da
}

return logical.ErrorResponse("token pending issuance"), nil
case !b.tokenValid(entry.Token, defaultExpiryDelta):
case !b.tokenValid(entry.Token.Token, defaultExpiryDelta):
if entry.AuthServerError != "" {
return logical.ErrorResponse("server %q has configuration problems: %s", entry.AuthServerName, entry.AuthServerError), nil
} else if entry.UserError != "" {
Expand All @@ -43,26 +46,60 @@ func (b *backend) stsReadOperation(ctx context.Context, req *logical.Request, da
return logical.ErrorResponse("token expired"), nil
}

ops, put, err := b.getProviderOperations(ctx, req.Storage, persistence.AuthServerName(entry.AuthServerName), defaultExpiryDelta)
if errmark.MarkedUser(err) {
return logical.ErrorResponse(fmt.Errorf("server %q has configuration problems: %w", entry.AuthServerName, errmark.MarkShort(err)).Error()), nil
} else if err != nil {
return nil, err
}
defer put()
scopes := data.Get("scopes").([]string)
audiences := data.Get("audiences").([]string)
resources := data.Get("resources").([]string)
exchangeKey := "scopes=" + strings.Join(scopes, " ") +
",audiences=" + strings.Join(audiences, " ") +
",resources=" + strings.Join(resources, " ")
expiryDelta := time.Duration(data.Get("minimum_seconds").(int)) * time.Second

tok, err := ops.TokenExchange(
ctx,
entry.Token,
provider.WithScopes(data.Get("scopes").([]string)),
provider.WithAudiences(data.Get("audiences").([]string)),
provider.WithResources(data.Get("resources").([]string)),
provider.WithProviderOptions(entry.ProviderOptions),
)
if errmark.MarkedUser(err) {
return logical.ErrorResponse(errmap.Wrap(errmark.MarkShort(err), "exchange failed").Error()), nil
} else if err != nil {
return nil, err
tok, ok := entry.ExchangedTokens[exchangeKey]
if !ok || !b.tokenValid(tok, expiryDelta) {
ops, put, err := b.getProviderOperations(ctx, req.Storage, persistence.AuthServerName(entry.AuthServerName), defaultExpiryDelta)
if errmark.MarkedUser(err) {
return logical.ErrorResponse(fmt.Errorf("server %q has configuration problems: %w", entry.AuthServerName, errmark.MarkShort(err)).Error()), nil
} else if err != nil {
return nil, err
}
defer put()

exchangedTok, err := ops.TokenExchange(
ctx,
entry.Token,
provider.WithScopes(scopes),
provider.WithAudiences(audiences),
provider.WithResources(resources),
provider.WithProviderOptions(entry.ProviderOptions),
)
if errmark.MarkedUser(err) {
return logical.ErrorResponse(errmap.Wrap(errmark.MarkShort(err), "exchange failed").Error()), nil
} else if err != nil {
return nil, err
}
if !b.tokenValid(exchangedTok.Token, expiryDelta) {
return logical.ErrorResponse("token expired"), nil
}

// copy into smaller struct for caching
tok = &oauth2.Token{
AccessToken: exchangedTok.Token.AccessToken,
TokenType: exchangedTok.Token.TokenType,
Expiry: exchangedTok.Token.Expiry,
}

if !tok.Expiry.IsZero() {
// Cache the token since it has an expiration time
err = b.storeExchangedToken(
ctx,
req.Storage,
keyer,
exchangeKey,
tok)
if err != nil {
return nil, err
}
}
}

rd := map[string]interface{}{
Expand Down Expand Up @@ -103,6 +140,12 @@ var stsFields = map[string]*framework.FieldSchema{
Description: "Specifies the target RFC 8707 resource indicators for the minted token.",
Query: true,
},
"minimum_seconds": {
Type: framework.TypeDurationSecond,
Description: "Minimum remaining seconds to allow when reusing exchanged access token.",
Default: 0,
Query: true,
},
}

const stsHelpSynopsis = `
Expand Down
6 changes: 3 additions & 3 deletions pkg/backend/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import (
"time"

"github.com/puppetlabs/leg/timeutil/pkg/clock"
"github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/provider"
"golang.org/x/oauth2"
)

const (
defaultExpiryDelta = 10 * time.Second
)

func tokenExpired(clk clock.Clock, t *provider.Token, expiryDelta time.Duration) bool {
func tokenExpired(clk clock.Clock, t *oauth2.Token, expiryDelta time.Duration) bool {
if t.Expiry.IsZero() {
return false
}
Expand All @@ -23,6 +23,6 @@ func tokenExpired(clk clock.Clock, t *provider.Token, expiryDelta time.Duration)
return t.Expiry.Round(0).Add(-expiryDelta).Before(clk.Now())
}

func (b *backend) tokenValid(tok *provider.Token, expiryDelta time.Duration) bool {
func (b *backend) tokenValid(tok *oauth2.Token, expiryDelta time.Duration) bool {
return tok != nil && tok.AccessToken != "" && !tokenExpired(b.clock, tok, expiryDelta)
}
38 changes: 36 additions & 2 deletions pkg/backend/token_authcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/puppetlabs/leg/timeutil/pkg/retry"
"github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/persistence"
"github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/provider"
"golang.org/x/oauth2"
)

type refreshProcess struct {
Expand Down Expand Up @@ -110,7 +111,7 @@ func (b *backend) refreshCredToken(ctx context.Context, storage logical.Storage,
switch {
case err != nil || candidate == nil:
return err
case !candidate.TokenIssued() || b.tokenValid(candidate.Token, expiryDelta) || candidate.RefreshToken == "":
case !candidate.TokenIssued() || b.tokenValid(candidate.Token.Token, expiryDelta) || candidate.RefreshToken == "":
entry = candidate
return nil
}
Expand Down Expand Up @@ -155,9 +156,42 @@ func (b *backend) getRefreshCredToken(ctx context.Context, storage logical.Stora
return nil, err
case entry == nil:
return nil, nil
case !entry.TokenIssued() || b.tokenValid(entry.Token, expiryDelta):
case !entry.TokenIssued() || b.tokenValid(entry.Token.Token, expiryDelta):
return entry, nil
default:
return b.refreshCredToken(ctx, storage, keyer, expiryDelta)
}
}

func (b *backend) storeExchangedToken(ctx context.Context, storage logical.Storage, keyer persistence.AuthCodeKeyer, exchangeKey string, tok *oauth2.Token) error {
ctx = clockctx.WithClock(ctx, b.clock)

err := b.data.AuthCode.WithLock(keyer, func(ach *persistence.LockedAuthCodeHolder) error {
acm := ach.Manager(storage)

entry, err := acm.ReadAuthCodeEntry(ctx)
if err != nil || entry == nil {
return err
}

if entry.ExchangedTokens == nil {
// first time, make the map
entry.ExchangedTokens = make(map[string]*oauth2.Token)
} else {
// remove every expired exchanged token while we're here
for k, t := range entry.ExchangedTokens {
if !b.tokenValid(t, defaultExpiryDelta) {
delete(entry.ExchangedTokens, k)
}
}
}
entry.ExchangedTokens[exchangeKey] = tok

if err := acm.WriteAuthCodeEntry(ctx, entry); err != nil {
return err
}

return nil
})
return err
}
Loading

0 comments on commit b3f6ff4

Please sign in to comment.