diff --git a/path_oidc.go b/path_oidc.go index 6c5a9ee2..263c9d3d 100644 --- a/path_oidc.go +++ b/path_oidc.go @@ -13,6 +13,7 @@ import ( "github.com/coreos/go-oidc" "github.com/hashicorp/errwrap" "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/cidrutil" "github.com/hashicorp/vault/sdk/helper/strutil" @@ -215,6 +216,7 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, } var rawToken string + var refreshToken string var oauth2Token *oauth2.Token code := d.Get("code").(string) @@ -239,6 +241,13 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, if !ok { return logical.ErrorResponse(errTokenVerification + " No id_token found in response."), nil } + + if role.RefreshStorePath != "" { + refreshToken, ok = oauth2Token.Extra("refresh_token").(string) + if !ok { + return logical.ErrorResponse(errTokenVerification + " No refresh_token found in response."), nil + } + } } if role.VerboseOIDCLogging { @@ -277,6 +286,9 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, } else { b.Logger().Debug("OIDC provider response", "marshalling error", err.Error()) } + if refreshToken != "" { + b.Logger().Debug("OIDC provider response", "refresh_token", refreshToken) + } } if err := validateBoundClaims(b.Logger(), role.BoundClaimsType, role.BoundClaims, allClaims); err != nil { @@ -314,6 +326,42 @@ func (b *jwtAuthBackend) pathCallback(ctx context.Context, req *logical.Request, role.PopulateTokenAuth(auth) + if role.RefreshStorePath != "" && refreshToken != "" { + vaultConfig := api.DefaultConfig() + client, err := api.NewClient(vaultConfig) + if err != nil { + return logical.ErrorResponse("error getting vault client: %s", err.Error()), nil + } + if role.RefreshStoreCred != "" { + client.SetToken(role.RefreshStoreCred) + } + storePath := role.RefreshStorePath + for { + /* replace claims in storePath of form {{claim}} with its value */ + start := strings.Index(storePath, "{{") + if start == -1 { + break + } + end := strings.Index(storePath, "}}") + if end < start { + return logical.ErrorResponse("mismatched brackets in refresh_store_path %s", role.RefreshStorePath), nil + } + claim := storePath[start+2:end] + if val, ok := allClaims[claim]; ok { + storePath = strings.ReplaceAll(storePath, "{{" + claim + "}}", val.(string)) + } else { + return logical.ErrorResponse("no claim %s found for refresh_store_path %s", claim, role.RefreshStorePath), nil + } + } + _, err = client.Logical().Write(storePath, map[string]interface{} { + "refresh_token": refreshToken, + }) + if err != nil { + return logical.ErrorResponse("error storing refresh token at %s: %s", storePath, err.Error()), nil + } + b.Logger().Debug("OIDC stored refresh token", "store_path", storePath) + } + resp := &logical.Response{ Auth: auth, } diff --git a/path_role.go b/path_role.go index 8216b7e7..f845fc1c 100644 --- a/path_role.go +++ b/path_role.go @@ -144,6 +144,14 @@ Defaults to 60 (1 minute) if set to 0 and can be disabled if set to -1.`, Not recommended in production since sensitive information may be present in OIDC responses.`, }, + "refresh_store_path": { + Type: framework.TypeString, + Description: `Vault path to store refresh token`, + }, + "refresh_store_cred": { + Type: framework.TypeString, + Description: `Vault credential (that is, token) to use when storing refresh token`, + }, }, ExistenceCheck: b.pathRoleExistenceCheck, Operations: map[logical.Operation]framework.OperationHandler{ @@ -202,6 +210,8 @@ type jwtRole struct { OIDCScopes []string `json:"oidc_scopes"` AllowedRedirectURIs []string `json:"allowed_redirect_uris"` VerboseOIDCLogging bool `json:"verbose_oidc_logging"` + RefreshStorePath string `json:"refresh_store_path"` + RefreshStoreCred string `json:"refresh_store_cred"` // Deprecated by TokenParams Policies []string `json:"policies"` @@ -308,6 +318,8 @@ func (b *jwtAuthBackend) pathRoleRead(ctx context.Context, req *logical.Request, "allowed_redirect_uris": role.AllowedRedirectURIs, "oidc_scopes": role.OIDCScopes, "verbose_oidc_logging": role.VerboseOIDCLogging, + "refresh_store_path": role.RefreshStorePath, + "refresh_store_cred": role.RefreshStoreCred, } role.PopulateTokenData(d) @@ -441,6 +453,13 @@ func (b *jwtAuthBackend) pathRoleCreateUpdate(ctx context.Context, req *logical. role.VerboseOIDCLogging = verboseOIDCLoggingRaw.(bool) } + if refreshStorePath, ok := data.GetOk("refresh_store_path"); ok { + role.RefreshStorePath = refreshStorePath.(string) + } + if refreshStoreCred, ok := data.GetOk("refresh_store_cred"); ok { + role.RefreshStoreCred = refreshStoreCred.(string) + } + boundClaimsType := data.Get("bound_claims_type").(string) switch boundClaimsType { case boundClaimsTypeString, boundClaimsTypeGlob: