From 68739e580efd311eb8b6bd33e5484da476f735c0 Mon Sep 17 00:00:00 2001 From: Francisco Santiago Date: Thu, 26 Oct 2017 14:05:20 +0200 Subject: [PATCH] Added User Info endoint for oidc connectors --- Documentation/oidc-connector.md | 6 ++ connector/connector.go | 5 ++ connector/oidc/oidc.go | 71 +++++++++++++++++++++++ examples/config-dev.yaml | 1 + server/handlers.go | 78 +++++++++++++++++++++++++ server/server.go | 1 + storage/conformance/conformance.go | 93 +++++++++++++++++++++++++++++- storage/kubernetes/storage.go | 34 +++++++++++ storage/kubernetes/storage_test.go | 1 + storage/kubernetes/types.go | 67 +++++++++++++++++++++ storage/memory/memory.go | 41 +++++++++++++ storage/sql/config_test.go | 1 + storage/sql/crud.go | 45 +++++++++++++++ storage/sql/migrate.go | 7 +++ storage/storage.go | 20 +++++++ 15 files changed, 469 insertions(+), 2 deletions(-) diff --git a/Documentation/oidc-connector.md b/Documentation/oidc-connector.md index 8171bc373e..a525d81ea2 100644 --- a/Documentation/oidc-connector.md +++ b/Documentation/oidc-connector.md @@ -42,6 +42,12 @@ connectors: # following field. # # basicAuthUnsupported: true + + # Some clients require the possibility to get further user details provided + # by via the UserInfo endpoint. + # + # See: http://openid.net/specs/openid-connect-core-1_0.html#UserInfo + # userInfo: https://www.googleapis.com/oauth2/v3/userinfo # Google supports whitelisting allowed domains when using G Suite # (Google Apps). The following field can be set to a list of domains diff --git a/connector/connector.go b/connector/connector.go index bc5f3b18a5..5e98ca4216 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -65,6 +65,11 @@ type CallbackConnector interface { HandleCallback(s Scopes, r *http.Request) (identity Identity, err error) } +// UserInfoConnector represents connectors that support the user info endpoint +type UserInfoConnector interface { + GetUserInfo(connData []byte) (user map[string]interface{}, err error) +} + // SAMLConnector represents SAML connectors which implement the HTTP POST binding. // RelayState is handled by the server. // diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index f0d8daf704..a0d41b48a6 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -9,6 +9,8 @@ import ( "net/url" "strings" "sync" + "encoding/json" + "github.com/coreos/go-oidc" "github.com/sirupsen/logrus" @@ -17,6 +19,17 @@ import ( "github.com/coreos/dex/connector" ) +type connectorData struct { + // Store token, even if it eventually expires + AccessToken string `json:"accessToken"` + TokenType string `json:"tokenType"` +} + +type connectorData2 struct { + // Store token, even if it eventually expires + AccessToken string `json:"accessToken"` +} + // Config holds configuration options for OpenID Connect logins. type Config struct { Issuer string `json:"issuer"` @@ -36,6 +49,8 @@ type Config struct { // Optional list of whitelisted domains when using Google // If this field is nonempty, only users from a listed domain will be allowed to log in HostedDomains []string `json:"hostedDomains"` + // Userinfo endpoint for Relying Parties that need details about the authenticated user + UserInfo string `json:"userInfo"` } // Domains that don't support basic auth. golang.org/x/oauth2 has an internal @@ -116,6 +131,7 @@ func (c *Config) Open(id string, logger logrus.FieldLogger) (conn connector.Conn logger: logger, cancel: cancel, hostedDomains: c.HostedDomains, + userInfo: c.UserInfo, }, nil } @@ -132,6 +148,7 @@ type oidcConnector struct { cancel context.CancelFunc logger logrus.FieldLogger hostedDomains []string + userInfo string } func (c *oidcConnector) Close() error { @@ -215,9 +232,63 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide Email: claims.Email, EmailVerified: claims.EmailVerified, } + + // Add AccessToken to user identity for future requests + data := connectorData{ + TokenType: "bearer", + AccessToken: token.AccessToken, + } + + connData, err := json.Marshal(data) + if err != nil { + return identity, fmt.Errorf("marshal connector data: %v", err) + } + + identity.ConnectorData = connData + return identity, nil } +func (c *oidcConnector) GetUserInfo(connData []byte) (user map[string]interface{}, err error) { + user = map[string]interface{}{} + + // Extract Access Token that was stored while the connector handled the OIDC handshake + var cData connectorData + + if err := json.Unmarshal(connData, &cData); err != nil { + c.logger.Errorf("Failed to read connector data : %v", err) + return user, err + } + + // Format Authorization header to request user info on behalf of the client + authorization := fmt.Sprintf("%s %s", cData.TokenType, cData.AccessToken) + // Prepare get request including Authorization header from RP + req, err := http.NewRequest("GET", c.userInfo, nil) + + if err != nil { + fmt.Errorf("Error Creating GET request: %v", err) + return user, err + } + + req.Header.Add("Accept", `application/json`) + req.Header.Add("Content-Type", `application/json`) + req.Header.Add("Authorization", authorization) + + // Prepare http client to execute get request + client := &http.Client{} + + resp, err := client.Do(req) + if err != nil { + fmt.Errorf("Error Executing GET request: %v", err) + return user, err + } + + json.NewDecoder(resp.Body).Decode(&user) + + + return user, nil +} + // Refresh is implemented for backwards compatibility, even though it's a no-op. func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) { return identity, nil diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 542c0caeb2..c6d9aada10 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -63,6 +63,7 @@ connectors: # name: Google # config: # issuer: https://accounts.google.com +# userInfo: https://www.googleapis.com/oauth2/v3/userinfo # # Connector config values starting with a "$" will read from the environment. # clientID: $GOOGLE_CLIENT_ID # clientSecret: $GOOGLE_CLIENT_SECRET diff --git a/server/handlers.go b/server/handlers.go index 345cd49616..f9ffec1c4c 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -95,6 +95,7 @@ type discovery struct { Auth string `json:"authorization_endpoint"` Token string `json:"token_endpoint"` Keys string `json:"jwks_uri"` + UserInfo string `json:"userinfo_endpoint"` ResponseTypes []string `json:"response_types_supported"` Subjects []string `json:"subject_types_supported"` IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` @@ -109,6 +110,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) { Auth: s.absURL("/auth"), Token: s.absURL("/token"), Keys: s.absURL("/keys"), + UserInfo: s.absURL("/userinfo"), Subjects: []string{"public"}, IDTokenAlgs: []string{string(jose.RS256)}, Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, @@ -656,7 +658,24 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s return } + // Random token to be handed to the client as Access Token accessToken := storage.NewID() + + // Create and store connector information under this access token + // so it can be used for future requests on behalf of the client + newAccessToken := storage.AccessToken{ + ID: accessToken, + ConnectorData: authCode.ConnectorData, + ConnectorID: authCode.ConnectorID, + Expiry: s.now().Add(time.Minute * 30), + } + + if err := s.storage.CreateAccessToken(newAccessToken); err != nil { + s.logger.Errorf("Failed to create access token: %v", err) + s.renderError(w, http.StatusInternalServerError, "Internal server error.") + return + } + idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) @@ -994,6 +1013,65 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, r w.Write(data) } +// Request userinfo data from the original IdP on behalf of the Relying Party (client) +func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { + + // Extract token type and token value from Authorization header + authorization_substrs := strings.Split(r.Header.Get("Authorization"), " ") + clientAccessToken := authorization_substrs[1] + + // Retrieve previously saved actual Access Token to allow requesting user info from upstream source + accessToken, err := s.storage.GetAccessToken(clientAccessToken) + if err != nil || s.now().After(accessToken.Expiry) { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get access token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + } else { + s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired access token.", http.StatusBadRequest) + } + return + } + + // Delegate the fetching of User info data to the oidc connector + conn, err := s.getConnector(accessToken.ConnectorID) + + if err != nil { + s.logger.Errorf("Failed to get connector with id %q : %v", accessToken.ConnectorID, err) + s.renderError(w, http.StatusBadRequest, "Requested resource does not exist") + return + } + + var user map[string]interface{} + + switch conn := conn.Connector.(type) { + case connector.UserInfoConnector: + + lala := map[string]interface{}{} + if err := json.Unmarshal(accessToken.ConnectorData, &lala); err != nil { + s.logger.Errorf("Failed to read lala data : %v", err) + } + + user, err = conn.GetUserInfo(accessToken.ConnectorData) + default: + s.renderError(w, http.StatusBadRequest, "Requested resource does not exist.") + } + + s.writeUserInfo(w, user) +} + +func (s *Server) writeUserInfo(w http.ResponseWriter, user map[string]interface{}) { + + data, err := json.Marshal(user) + if err != nil { + s.logger.Errorf("failed to marshal access token response: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Length", strconv.Itoa(len(data))) + w.Write(data) +} + func (s *Server) renderError(w http.ResponseWriter, status int, description string) { if err := s.templates.err(w, http.StatusText(status), description); err != nil { s.logger.Errorf("Server template error: %v", err) diff --git a/server/server.go b/server/server.go index 65de3b834c..8e1c2e2b65 100644 --- a/server/server.go +++ b/server/server.go @@ -239,6 +239,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) // TODO(ericchiang): rate limit certain paths based on IP. handleWithCORS("/token", s.handleToken) handleWithCORS("/keys", s.handlePublicKeys) + handleWithCORS("/userinfo", s.handleUserInfo) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 0bfc474566..8c7b1c9d6b 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -42,6 +42,7 @@ func runTests(t *testing.T, newStorage func() storage.Storage, tests []subTest) func RunTests(t *testing.T, newStorage func() storage.Storage) { runTests(t, newStorage, []subTest{ {"AuthCodeCRUD", testAuthCodeCRUD}, + {"AccessTokenCRUD", testAccessTokenCRUD}, {"AuthRequestCRUD", testAuthRequestCRUD}, {"ClientCRUD", testClientCRUD}, {"RefreshTokenCRUD", testRefreshTokenCRUD}, @@ -238,6 +239,57 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { mustBeErrNotFound(t, "auth code", err) } +func testAccessTokenCRUD(t *testing.T, s storage.Storage) { + a1 := storage.AccessToken{ + ID: storage.NewID(), + ConnectorData: []byte(`{"some":"data"}`), + ConnectorID: "ldap", + Expiry: neverExpire, + } + + if err := s.CreateAccessToken(a1); err != nil { + t.Fatalf("failed creating access token: %v", err) + } + + a2 := storage.AccessToken{ + ID: storage.NewID(), + ConnectorData: []byte(`{"some":"data"}`), + ConnectorID: "ldap", + Expiry: neverExpire, + } + + // Attempt to create same AccessToken twice. + err := s.CreateAccessToken(a1) + mustBeErrAlreadyExists(t, "access token", err) + + if err := s.CreateAccessToken(a2); err != nil { + t.Fatalf("failed creating access token: %v", err) + } + + got, err := s.GetAccessToken(a1.ID) + if err != nil { + t.Fatalf("failed to get access token: %v", err) + } + if a1.Expiry.Unix() != got.Expiry.Unix() { + t.Errorf("access token expiry did not match want=%s vs got=%s", a1.Expiry, got.Expiry) + } + got.Expiry = a1.Expiry // time fields do not compare well + if diff := pretty.Compare(a1, got); diff != "" { + t.Errorf("access token retrieved from storage did not match: %s", diff) + } + + if err := s.DeleteAccessToken(a1.ID); err != nil { + t.Fatalf("delete access token: %v", err) + } + + if err := s.DeleteAccessToken(a2.ID); err != nil { + t.Fatalf("delete access token: %v", err) + } + + _, err = s.GetAccessToken(a1.ID) + mustBeErrNotFound(t, "access token", err) +} + func testClientCRUD(t *testing.T, s storage.Storage) { id1 := storage.NewID() c1 := storage.Client{ @@ -731,7 +783,7 @@ func testGC(t *testing.T, s storage.Storage) { if err != nil { t.Errorf("garbage collection failed: %v", err) } else { - if result.AuthCodes != 0 || result.AuthRequests != 0 { + if result.AuthCodes != 0 || result.AccessTokens != 0 || result.AuthRequests != 0 { t.Errorf("expected no garbage collection results, got %#v", result) } } @@ -752,6 +804,43 @@ func testGC(t *testing.T, s storage.Storage) { t.Errorf("expected storage.ErrNotFound, got %v", err) } + b := storage.AccessToken{ + ID: storage.NewID(), + ConnectorData: []byte(`{"some":"data"}`), + ConnectorID: "ldap", + Expiry: expiry, + } + + if err := s.CreateAccessToken(b); err != nil { + t.Fatalf("failed creating access token: %v", err) + } + + for _, tz := range []*time.Location{time.UTC, est, pst} { + result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) + if err != nil { + t.Errorf("garbage collection failed: %v", err) + } else { + if result.AccessTokens != 0 || result.AccessTokens != 0 || result.AuthRequests != 0 { + t.Errorf("expected no garbage collection results, got %#v", result) + } + } + if _, err := s.GetAccessToken(b.ID); err != nil { + t.Errorf("expected to be able to get access token after GC: %v", err) + } + } + + if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { + t.Errorf("garbage collection failed: %v", err) + } else if r.AccessTokens != 1 { + t.Errorf("expected to garbage collect 1 objects, got %d", r.AccessTokens) + } + + if _, err := s.GetAccessToken(b.ID); err == nil { + t.Errorf("expected access token to be GC'd") + } else if err != storage.ErrNotFound { + t.Errorf("expected storage.ErrNotFound, got %v", err) + } + a := storage.AuthRequest{ ID: storage.NewID(), ClientID: "foobar", @@ -783,7 +872,7 @@ func testGC(t *testing.T, s storage.Storage) { if err != nil { t.Errorf("garbage collection failed: %v", err) } else { - if result.AuthCodes != 0 || result.AuthRequests != 0 { + if result.AuthCodes != 0 || result.AccessTokens != 0 || result.AuthRequests != 0 { t.Errorf("expected no garbage collection results, got %#v", result) } } diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index 1cb0fd078c..267318a9a0 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -14,6 +14,7 @@ import ( const ( kindAuthCode = "AuthCode" + kindAccessToken = "AccessToken" kindAuthRequest = "AuthRequest" kindClient = "OAuth2Client" kindRefreshToken = "RefreshToken" @@ -25,6 +26,7 @@ const ( const ( resourceAuthCode = "authcodes" + resourceAccessToken = "accesstokens" resourceAuthRequest = "authrequests" resourceClient = "oauth2clients" resourceRefreshToken = "refreshtokens" @@ -184,6 +186,10 @@ func (cli *client) CreateAuthCode(c storage.AuthCode) error { return cli.post(resourceAuthCode, cli.fromStorageAuthCode(c)) } +func (cli *client) CreateAccessToken(c storage.AccessToken) error { + return cli.post(resourceAccessToken, cli.fromStorageAccessToken(c)) +} + func (cli *client) CreatePassword(p storage.Password) error { return cli.post(resourcePassword, cli.fromStoragePassword(p)) } @@ -216,6 +222,14 @@ func (cli *client) GetAuthCode(id string) (storage.AuthCode, error) { return toStorageAuthCode(code), nil } +func (cli *client) GetAccessToken(id string) (storage.AccessToken, error) { + var token AccessToken + if err := cli.get(resourceAccessToken, id, &token); err != nil { + return storage.AccessToken{}, err + } + return toStorageAccessToken(token), nil +} + func (cli *client) GetClient(id string) (storage.Client, error) { c, err := cli.getClient(id) if err != nil { @@ -355,6 +369,10 @@ func (cli *client) DeleteAuthCode(code string) error { return cli.delete(resourceAuthCode, code) } +func (cli *client) DeleteAccessToken(token string) error { + return cli.delete(resourceAccessToken, token) +} + func (cli *client) DeleteClient(id string) error { // Check for hash collition. c, err := cli.getClient(id) @@ -550,5 +568,21 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e result.AuthCodes++ } } + + var accessTokens AccessTokenList + if err := cli.list(resourceAccessToken, &accessTokens); err != nil { + return result, fmt.Errorf("failed to list access tokens: %v", err) + } + + for _, accessToken := range accessTokens.AccessTokens { + if now.After(accessToken.Expiry) { + if err := cli.delete(resourceAccessToken, accessToken.ObjectMeta.Name); err != nil { + cli.logger.Errorf("failed to delete access tokens %v", err) + delErr = fmt.Errorf("failed to delete access tokens: %v", err) + } + result.AccessTokens++ + } + } + return result, delErr } diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go index 58340f7057..a94e5e3782 100644 --- a/storage/kubernetes/storage_test.go +++ b/storage/kubernetes/storage_test.go @@ -88,6 +88,7 @@ func TestStorage(t *testing.T) { newStorage := func() storage.Storage { for _, resource := range []string{ resourceAuthCode, + resourceAccessToken, resourceAuthRequest, resourceClient, resourceRefreshToken, diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index aa2965312f..d16898551c 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -26,6 +26,14 @@ var thirdPartyResources = []k8sapi.ThirdPartyResource{ Description: "A code which can be claimed for an access token.", Versions: []k8sapi.APIVersion{{Name: "v1"}}, }, + { + ObjectMeta: k8sapi.ObjectMeta{ + Name: "access-token.oidc.coreos.com", + }, + TypeMeta: tprMeta, + Description: "A token which can be used to access resources.", + Versions: []k8sapi.APIVersion{{Name: "v1"}}, + }, { ObjectMeta: k8sapi.ObjectMeta{ Name: "auth-request.oidc.coreos.com", @@ -109,6 +117,21 @@ var customResourceDefinitions = []k8sapi.CustomResourceDefinition{ }, }, }, + { + ObjectMeta: k8sapi.ObjectMeta{ + Name: "accesstokens.dex.coreos.com", + }, + TypeMeta: crdMeta, + Spec: k8sapi.CustomResourceDefinitionSpec{ + Group: apiGroup, + Version: "v1", + Names: k8sapi.CustomResourceDefinitionNames{ + Plural: "accesstokens", + Singular: "accesstoken", + Kind: "AccessToken", + }, + }, + }, { ObjectMeta: k8sapi.ObjectMeta{ Name: "authrequests.dex.coreos.com", @@ -507,6 +530,50 @@ func toStorageAuthCode(a AuthCode) storage.AuthCode { } } +// AccessToken is a mirrored struct from storage with JSON struct tags and +// Kubernetes type metadata. +type AccessToken struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ObjectMeta `json:"metadata,omitempty"` + + ConnectorData []byte `json:"connectorData,omitempty"` + ConnectorID string `json:"connectorID,omitempty"` + + Expiry time.Time `json:"expiry"` +} + +// AccessTokenList is a list of AccessTokens. +type AccessTokenList struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ListMeta `json:"metadata,omitempty"` + AccessTokens []AccessToken `json:"items"` +} + +func (cli *client) fromStorageAccessToken(a storage.AccessToken) AccessToken { + return AccessToken{ + TypeMeta: k8sapi.TypeMeta{ + Kind: kindAccessToken, + APIVersion: cli.apiVersion, + }, + ObjectMeta: k8sapi.ObjectMeta{ + Name: a.ID, + Namespace: cli.namespace, + }, + ConnectorData: a.ConnectorData, + ConnectorID: a.ConnectorID, + Expiry: a.Expiry, + } +} + +func toStorageAccessToken(a AccessToken) storage.AccessToken { + return storage.AccessToken{ + ID: a.ObjectMeta.Name, + ConnectorData: a.ConnectorData, + ConnectorID: a.ConnectorID, + Expiry: a.Expiry, + } +} + // RefreshToken is a mirrored struct from storage with JSON struct tags and // Kubernetes type metadata. type RefreshToken struct { diff --git a/storage/memory/memory.go b/storage/memory/memory.go index ed80778b63..a4c889a482 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -15,6 +15,7 @@ func New(logger logrus.FieldLogger) storage.Storage { return &memStorage{ clients: make(map[string]storage.Client), authCodes: make(map[string]storage.AuthCode), + accessTokens: make(map[string]storage.AccessToken), refreshTokens: make(map[string]storage.RefreshToken), authReqs: make(map[string]storage.AuthRequest), passwords: make(map[string]storage.Password), @@ -41,6 +42,7 @@ type memStorage struct { clients map[string]storage.Client authCodes map[string]storage.AuthCode + accessTokens map[string]storage.AccessToken refreshTokens map[string]storage.RefreshToken authReqs map[string]storage.AuthRequest passwords map[string]storage.Password @@ -79,6 +81,12 @@ func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err result.AuthRequests++ } } + for id, a := range s.accessTokens { + if now.After(a.Expiry) { + delete(s.accessTokens, id) + result.AccessTokens++ + } + } }) return result, nil } @@ -105,6 +113,17 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) { return } +func (s *memStorage) CreateAccessToken(c storage.AccessToken) (err error) { + s.tx(func() { + if _, ok := s.accessTokens[c.ID]; ok { + err = storage.ErrAlreadyExists + } else { + s.accessTokens[c.ID] = c + } + }) + return +} + func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) { s.tx(func() { if _, ok := s.refreshTokens[r.ID]; ok { @@ -176,6 +195,17 @@ func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) { return } +func (s *memStorage) GetAccessToken(id string) (c storage.AccessToken, err error) { + s.tx(func() { + var ok bool + if c, ok = s.accessTokens[id]; !ok { + err = storage.ErrNotFound + return + } + }) + return +} + func (s *memStorage) GetPassword(email string) (p storage.Password, err error) { email = strings.ToLower(email) s.tx(func() { @@ -330,6 +360,17 @@ func (s *memStorage) DeleteAuthCode(id string) (err error) { return } +func (s *memStorage) DeleteAccessToken(id string) (err error) { + s.tx(func() { + if _, ok := s.accessTokens[id]; !ok { + err = storage.ErrNotFound + return + } + delete(s.accessTokens, id) + }) + return +} + func (s *memStorage) DeleteAuthRequest(id string) (err error) { s.tx(func() { if _, ok := s.authReqs[id]; !ok { diff --git a/storage/sql/config_test.go b/storage/sql/config_test.go index 13b2508e4c..8bc9e4e84a 100644 --- a/storage/sql/config_test.go +++ b/storage/sql/config_test.go @@ -35,6 +35,7 @@ func cleanDB(c *conn) error { delete from client; delete from auth_request; delete from auth_code; + delete from access_token; delete from refresh_token; delete from keys; delete from password; diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 17886b91d0..fbcf7aabfa 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -100,6 +100,14 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error if n, err := r.RowsAffected(); err == nil { result.AuthCodes = n } + + r, err = c.Exec(`delete from access_token where expiry < $1`, now) + if err != nil { + return result, fmt.Errorf("gc access_token: %v", err) + } + if n, err := r.RowsAffected(); err == nil { + result.AccessTokens = n + } return } @@ -248,6 +256,42 @@ func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { return a, nil } +func (c *conn) CreateAccessToken(a storage.AccessToken) error { + _, err := c.Exec(` + insert into access_token ( + id, connector_data, connector_id, expiry + ) + values ($1, $2, $3, $4); + `, + a.ID, a.ConnectorData, a.ConnectorID, a.Expiry, + ) + + if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } + return fmt.Errorf("insert access token: %v", err) + } + return nil +} + +func (c *conn) GetAccessToken(id string) (a storage.AccessToken, err error) { + err = c.QueryRow(` + select + id, connector_data, connector_id, expiry + from access_token where id = $1; + `, id).Scan( + &a.ID, &a.ConnectorData, &a.ConnectorID, &a.Expiry, + ) + if err != nil { + if err == sql.ErrNoRows { + return a, storage.ErrNotFound + } + return a, fmt.Errorf("select access token: %v", err) + } + return a, nil +} + func (c *conn) CreateRefresh(r storage.RefreshToken) error { _, err := c.Exec(` insert into refresh_token ( @@ -817,6 +861,7 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) { func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) } func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) } +func (c *conn) DeleteAccessToken(id string) error { return c.delete("access_token", "id", id) } func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) } func (c *conn) DeleteRefresh(id string) error { return c.delete("refresh_token", "id", id) } func (c *conn) DeletePassword(email string) error { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 6341037a41..08cc0794f9 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -121,6 +121,13 @@ var migrations = []migration{ expiry timestamptz not null ); + + create table access_token ( + id text not null primary key, + connector_id text not null, + connector_data blob not null, + expiry timestamptz not null + ); create table refresh_token ( id text not null primary key, diff --git a/storage/storage.go b/storage/storage.go index 893fb10035..197f49e9da 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -38,6 +38,7 @@ func NewID() string { type GCResult struct { AuthRequests int64 AuthCodes int64 + AccessTokens int64 } // Storage is the storage interface used by the server. Implementations are @@ -50,6 +51,7 @@ type Storage interface { CreateAuthRequest(a AuthRequest) error CreateClient(c Client) error CreateAuthCode(c AuthCode) error + CreateAccessToken(c AccessToken) error CreateRefresh(r RefreshToken) error CreatePassword(p Password) error CreateOfflineSessions(s OfflineSessions) error @@ -59,6 +61,7 @@ type Storage interface { // requests that way instead of using ErrNotFound. GetAuthRequest(id string) (AuthRequest, error) GetAuthCode(id string) (AuthCode, error) + GetAccessToken(id string) (AccessToken, error) GetClient(id string) (Client, error) GetKeys() (Keys, error) GetRefresh(id string) (RefreshToken, error) @@ -74,6 +77,7 @@ type Storage interface { // Delete methods MUST be atomic. DeleteAuthRequest(id string) error DeleteAuthCode(code string) error + DeleteAccessToken(id string) error DeleteClient(id string) error DeleteRefresh(id string) error DeletePassword(email string) error @@ -219,6 +223,22 @@ type AuthCode struct { Expiry time.Time } +// AccessToken represents all information associated with an access token +// which is sent to a client that authenticated with dex +// This value is stored once once Dex is authenticated with an upstream source +type AccessToken struct { + // Random fake Token sent to the client as if it were the real AccessToken + ID string + + // Authorization data provided by an upstream source. + ConnectorData []byte + + // The connector this access token is valid for + ConnectorID string + + Expiry time.Time +} + // RefreshToken is an OAuth2 refresh token which allows a client to request new // tokens on the end user's behalf. type RefreshToken struct {