Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added User Info endpoint #1107

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Documentation/oidc-connector.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ connectors:
# following field.
#
# basicAuthUnsupported: true

# Some clients require the possibility to get further user details provided
# by via the UserInfo endpoint.
#
# See: http://openid.net/specs/openid-connect-core-1_0.html#UserInfo
# userInfo: https://www.googleapis.com/oauth2/v3/userinfo

# Google supports whitelisting allowed domains when using G Suite
# (Google Apps). The following field can be set to a list of domains
Expand Down
5 changes: 5 additions & 0 deletions connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ type CallbackConnector interface {
HandleCallback(s Scopes, r *http.Request) (identity Identity, err error)
}

// UserInfoConnector represents connectors that support the user info endpoint
type UserInfoConnector interface {
GetUserInfo(connData []byte) (user map[string]interface{}, err error)
}

// SAMLConnector represents SAML connectors which implement the HTTP POST binding.
// RelayState is handled by the server.
//
Expand Down
71 changes: 71 additions & 0 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"net/url"
"strings"
"sync"
"encoding/json"


"github.com/coreos/go-oidc"
"github.com/sirupsen/logrus"
Expand All @@ -17,6 +19,17 @@ import (
"github.com/coreos/dex/connector"
)

type connectorData struct {
// Store token, even if it eventually expires
AccessToken string `json:"accessToken"`
TokenType string `json:"tokenType"`
}

type connectorData2 struct {
// Store token, even if it eventually expires
AccessToken string `json:"accessToken"`
}

// Config holds configuration options for OpenID Connect logins.
type Config struct {
Issuer string `json:"issuer"`
Expand All @@ -36,6 +49,8 @@ type Config struct {
// Optional list of whitelisted domains when using Google
// If this field is nonempty, only users from a listed domain will be allowed to log in
HostedDomains []string `json:"hostedDomains"`
// Userinfo endpoint for Relying Parties that need details about the authenticated user
UserInfo string `json:"userInfo"`
}

// Domains that don't support basic auth. golang.org/x/oauth2 has an internal
Expand Down Expand Up @@ -116,6 +131,7 @@ func (c *Config) Open(id string, logger logrus.FieldLogger) (conn connector.Conn
logger: logger,
cancel: cancel,
hostedDomains: c.HostedDomains,
userInfo: c.UserInfo,
}, nil
}

Expand All @@ -132,6 +148,7 @@ type oidcConnector struct {
cancel context.CancelFunc
logger logrus.FieldLogger
hostedDomains []string
userInfo string
}

func (c *oidcConnector) Close() error {
Expand Down Expand Up @@ -215,9 +232,63 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
Email: claims.Email,
EmailVerified: claims.EmailVerified,
}

// Add AccessToken to user identity for future requests
data := connectorData{
TokenType: "bearer",
AccessToken: token.AccessToken,
}

connData, err := json.Marshal(data)
if err != nil {
return identity, fmt.Errorf("marshal connector data: %v", err)
}

identity.ConnectorData = connData

return identity, nil
}

func (c *oidcConnector) GetUserInfo(connData []byte) (user map[string]interface{}, err error) {
user = map[string]interface{}{}

// Extract Access Token that was stored while the connector handled the OIDC handshake
var cData connectorData

if err := json.Unmarshal(connData, &cData); err != nil {
c.logger.Errorf("Failed to read connector data : %v", err)
return user, err
}

// Format Authorization header to request user info on behalf of the client
authorization := fmt.Sprintf("%s %s", cData.TokenType, cData.AccessToken)
// Prepare get request including Authorization header from RP
req, err := http.NewRequest("GET", c.userInfo, nil)

if err != nil {
fmt.Errorf("Error Creating GET request: %v", err)
return user, err
}

req.Header.Add("Accept", `application/json`)
req.Header.Add("Content-Type", `application/json`)
req.Header.Add("Authorization", authorization)

// Prepare http client to execute get request
client := &http.Client{}

resp, err := client.Do(req)
if err != nil {
fmt.Errorf("Error Executing GET request: %v", err)
return user, err
}

json.NewDecoder(resp.Body).Decode(&user)


return user, nil
}

// Refresh is implemented for backwards compatibility, even though it's a no-op.
func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
return identity, nil
Expand Down
1 change: 1 addition & 0 deletions examples/config-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ connectors:
# name: Google
# config:
# issuer: https://accounts.google.com
# userInfo: https://www.googleapis.com/oauth2/v3/userinfo
# # Connector config values starting with a "$" will read from the environment.
# clientID: $GOOGLE_CLIENT_ID
# clientSecret: $GOOGLE_CLIENT_SECRET
Expand Down
78 changes: 78 additions & 0 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ type discovery struct {
Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"`
Keys string `json:"jwks_uri"`
UserInfo string `json:"userinfo_endpoint"`
ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"`
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
Expand All @@ -109,6 +110,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
Auth: s.absURL("/auth"),
Token: s.absURL("/token"),
Keys: s.absURL("/keys"),
UserInfo: s.absURL("/userinfo"),
Subjects: []string{"public"},
IDTokenAlgs: []string{string(jose.RS256)},
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
Expand Down Expand Up @@ -656,7 +658,24 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return
}

// Random token to be handed to the client as Access Token
accessToken := storage.NewID()

// Create and store connector information under this access token
// so it can be used for future requests on behalf of the client
newAccessToken := storage.AccessToken{
ID: accessToken,
ConnectorData: authCode.ConnectorData,
ConnectorID: authCode.ConnectorID,
Expiry: s.now().Add(time.Minute * 30),
}

if err := s.storage.CreateAccessToken(newAccessToken); err != nil {
s.logger.Errorf("Failed to create access token: %v", err)
s.renderError(w, http.StatusInternalServerError, "Internal server error.")
return
}

idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create ID token: %v", err)
Expand Down Expand Up @@ -994,6 +1013,65 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, r
w.Write(data)
}

// Request userinfo data from the original IdP on behalf of the Relying Party (client)
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {

// Extract token type and token value from Authorization header
authorization_substrs := strings.Split(r.Header.Get("Authorization"), " ")
clientAccessToken := authorization_substrs[1]

// Retrieve previously saved actual Access Token to allow requesting user info from upstream source
accessToken, err := s.storage.GetAccessToken(clientAccessToken)
if err != nil || s.now().After(accessToken.Expiry) {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
} else {
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired access token.", http.StatusBadRequest)
}
return
}

// Delegate the fetching of User info data to the oidc connector
conn, err := s.getConnector(accessToken.ConnectorID)

if err != nil {
s.logger.Errorf("Failed to get connector with id %q : %v", accessToken.ConnectorID, err)
s.renderError(w, http.StatusBadRequest, "Requested resource does not exist")
return
}

var user map[string]interface{}

switch conn := conn.Connector.(type) {
case connector.UserInfoConnector:

lala := map[string]interface{}{}
if err := json.Unmarshal(accessToken.ConnectorData, &lala); err != nil {
s.logger.Errorf("Failed to read lala data : %v", err)
}

user, err = conn.GetUserInfo(accessToken.ConnectorData)
default:
s.renderError(w, http.StatusBadRequest, "Requested resource does not exist.")
}

s.writeUserInfo(w, user)
}

func (s *Server) writeUserInfo(w http.ResponseWriter, user map[string]interface{}) {

data, err := json.Marshal(user)
if err != nil {
s.logger.Errorf("failed to marshal access token response: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.Write(data)
}

func (s *Server) renderError(w http.ResponseWriter, status int, description string) {
if err := s.templates.err(w, http.StatusText(status), description); err != nil {
s.logger.Errorf("Server template error: %v", err)
Expand Down
1 change: 1 addition & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
// TODO(ericchiang): rate limit certain paths based on IP.
handleWithCORS("/token", s.handleToken)
handleWithCORS("/keys", s.handlePublicKeys)
handleWithCORS("/userinfo", s.handleUserInfo)
handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin)
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
Expand Down
93 changes: 91 additions & 2 deletions storage/conformance/conformance.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func runTests(t *testing.T, newStorage func() storage.Storage, tests []subTest)
func RunTests(t *testing.T, newStorage func() storage.Storage) {
runTests(t, newStorage, []subTest{
{"AuthCodeCRUD", testAuthCodeCRUD},
{"AccessTokenCRUD", testAccessTokenCRUD},
{"AuthRequestCRUD", testAuthRequestCRUD},
{"ClientCRUD", testClientCRUD},
{"RefreshTokenCRUD", testRefreshTokenCRUD},
Expand Down Expand Up @@ -238,6 +239,57 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
mustBeErrNotFound(t, "auth code", err)
}

func testAccessTokenCRUD(t *testing.T, s storage.Storage) {
a1 := storage.AccessToken{
ID: storage.NewID(),
ConnectorData: []byte(`{"some":"data"}`),
ConnectorID: "ldap",
Expiry: neverExpire,
}

if err := s.CreateAccessToken(a1); err != nil {
t.Fatalf("failed creating access token: %v", err)
}

a2 := storage.AccessToken{
ID: storage.NewID(),
ConnectorData: []byte(`{"some":"data"}`),
ConnectorID: "ldap",
Expiry: neverExpire,
}

// Attempt to create same AccessToken twice.
err := s.CreateAccessToken(a1)
mustBeErrAlreadyExists(t, "access token", err)

if err := s.CreateAccessToken(a2); err != nil {
t.Fatalf("failed creating access token: %v", err)
}

got, err := s.GetAccessToken(a1.ID)
if err != nil {
t.Fatalf("failed to get access token: %v", err)
}
if a1.Expiry.Unix() != got.Expiry.Unix() {
t.Errorf("access token expiry did not match want=%s vs got=%s", a1.Expiry, got.Expiry)
}
got.Expiry = a1.Expiry // time fields do not compare well
if diff := pretty.Compare(a1, got); diff != "" {
t.Errorf("access token retrieved from storage did not match: %s", diff)
}

if err := s.DeleteAccessToken(a1.ID); err != nil {
t.Fatalf("delete access token: %v", err)
}

if err := s.DeleteAccessToken(a2.ID); err != nil {
t.Fatalf("delete access token: %v", err)
}

_, err = s.GetAccessToken(a1.ID)
mustBeErrNotFound(t, "access token", err)
}

func testClientCRUD(t *testing.T, s storage.Storage) {
id1 := storage.NewID()
c1 := storage.Client{
Expand Down Expand Up @@ -731,7 +783,7 @@ func testGC(t *testing.T, s storage.Storage) {
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else {
if result.AuthCodes != 0 || result.AuthRequests != 0 {
if result.AuthCodes != 0 || result.AccessTokens != 0 || result.AuthRequests != 0 {
t.Errorf("expected no garbage collection results, got %#v", result)
}
}
Expand All @@ -752,6 +804,43 @@ func testGC(t *testing.T, s storage.Storage) {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}

b := storage.AccessToken{
ID: storage.NewID(),
ConnectorData: []byte(`{"some":"data"}`),
ConnectorID: "ldap",
Expiry: expiry,
}

if err := s.CreateAccessToken(b); err != nil {
t.Fatalf("failed creating access token: %v", err)
}

for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else {
if result.AccessTokens != 0 || result.AccessTokens != 0 || result.AuthRequests != 0 {
t.Errorf("expected no garbage collection results, got %#v", result)
}
}
if _, err := s.GetAccessToken(b.ID); err != nil {
t.Errorf("expected to be able to get access token after GC: %v", err)
}
}

if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.AccessTokens != 1 {
t.Errorf("expected to garbage collect 1 objects, got %d", r.AccessTokens)
}

if _, err := s.GetAccessToken(b.ID); err == nil {
t.Errorf("expected access token to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}

a := storage.AuthRequest{
ID: storage.NewID(),
ClientID: "foobar",
Expand Down Expand Up @@ -783,7 +872,7 @@ func testGC(t *testing.T, s storage.Storage) {
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else {
if result.AuthCodes != 0 || result.AuthRequests != 0 {
if result.AuthCodes != 0 || result.AccessTokens != 0 || result.AuthRequests != 0 {
t.Errorf("expected no garbage collection results, got %#v", result)
}
}
Expand Down
Loading