From 1e3e24f5723b94dc4243e7a9014eb3e3ebf5c079 Mon Sep 17 00:00:00 2001 From: Ishank Arora Date: Mon, 15 Nov 2021 14:20:00 +0100 Subject: [PATCH] OIDC driver changes for lightweight users --- changelog/unreleased/oidc-lw-users.md | 3 ++ pkg/auth/manager/oidc/oidc.go | 27 ++++++++++++++---- pkg/cbox/user/rest/rest.go | 40 +++++++++++++++++++++++---- 3 files changed, 59 insertions(+), 11 deletions(-) create mode 100644 changelog/unreleased/oidc-lw-users.md diff --git a/changelog/unreleased/oidc-lw-users.md b/changelog/unreleased/oidc-lw-users.md new file mode 100644 index 00000000000..2053f695513 --- /dev/null +++ b/changelog/unreleased/oidc-lw-users.md @@ -0,0 +1,3 @@ +Enhancement: OIDC driver changes for lightweight users + +https://github.com/cs3org/reva/pull/2278 \ No newline at end of file diff --git a/pkg/auth/manager/oidc/oidc.go b/pkg/auth/manager/oidc/oidc.go index c6dd62b371a..ef0ae2ccdb3 100644 --- a/pkg/auth/manager/oidc/oidc.go +++ b/pkg/auth/manager/oidc/oidc.go @@ -23,6 +23,7 @@ package oidc import ( "context" "fmt" + "strings" "time" oidc "github.com/coreos/go-oidc" @@ -130,15 +131,17 @@ func (am *mgr) Authenticate(ctx context.Context, clientID, clientSecret string) if claims["email_verified"] == nil { // This is not set in simplesamlphp claims["email_verified"] = false } + if claims["preferred_username"] == nil { + claims["preferred_username"] = claims[am.c.IDClaim] + } + if claims["name"] == nil { + claims["name"] = claims[am.c.IDClaim] + } if claims["email"] == nil { return nil, nil, fmt.Errorf("no \"email\" attribute found in userinfo: maybe the client did not request the oidc \"email\"-scope") } - if claims["preferred_username"] == nil || claims["name"] == nil { - return nil, nil, fmt.Errorf("no \"preferred_username\" or \"name\" attribute found in userinfo: maybe the client did not request the oidc \"profile\"-scope") - } - var uid, gid float64 if am.c.UIDClaim != "" { uid, _ = claims[am.c.UIDClaim].(float64) @@ -150,7 +153,7 @@ func (am *mgr) Authenticate(ctx context.Context, clientID, clientSecret string) userID := &user.UserId{ OpaqueId: claims[am.c.IDClaim].(string), // a stable non reassignable id Idp: claims["issuer"].(string), // in the scope of this issuer - Type: user.UserType_USER_TYPE_PRIMARY, + Type: getUserType(claims[am.c.IDClaim].(string)), } gwc, err := pool.GetGatewayServiceClient(am.c.GatewaySvc) if err != nil { @@ -228,3 +231,17 @@ func (am *mgr) getOIDCProvider(ctx context.Context) (*oidc.Provider, error) { am.provider = provider return am.provider, nil } + +func getUserType(upn string) user.UserType { + var t user.UserType + switch { + case strings.HasPrefix(upn, "guest"): + t = user.UserType_USER_TYPE_LIGHTWEIGHT + case strings.Contains(upn, "@"): + t = user.UserType_USER_TYPE_FEDERATED + default: + t = user.UserType_USER_TYPE_PRIMARY + } + return t + +} diff --git a/pkg/cbox/user/rest/rest.go b/pkg/cbox/user/rest/rest.go index eeedaba59a7..541e11d35bb 100644 --- a/pkg/cbox/user/rest/rest.go +++ b/pkg/cbox/user/rest/rest.go @@ -128,9 +128,7 @@ func (m *manager) Configure(ml map[string]interface{}) error { return nil } -func (m *manager) getUserByParam(ctx context.Context, param, val string) (map[string]interface{}, error) { - url := fmt.Sprintf("%s/Identity?filter=%s:%s&field=upn&field=primaryAccountEmail&field=displayName&field=uid&field=gid&field=type", - m.conf.APIBaseURL, param, url.QueryEscape(val)) +func (m *manager) getUser(ctx context.Context, url string) (map[string]interface{}, error) { responseData, err := m.apiTokenManager.SendAPIGetRequest(ctx, url, false) if err != nil { return nil, err @@ -151,17 +149,38 @@ func (m *manager) getUserByParam(ctx context.Context, param, val string) (map[st } if len(users) != 1 { - return nil, errors.New("rest: user not found: " + param + ": " + val) + return nil, errors.New("rest: user not found for URL: " + url) } return users[0], nil } +func (m *manager) getUserByParam(ctx context.Context, param, val string) (map[string]interface{}, error) { + url := fmt.Sprintf("%s/Identity?filter=%s:%s&field=upn&field=primaryAccountEmail&field=displayName&field=uid&field=gid&field=type", + m.conf.APIBaseURL, param, url.QueryEscape(val)) + return m.getUser(ctx, url) +} + +func (m *manager) getLightweightUser(ctx context.Context, mail string) (map[string]interface{}, error) { + url := fmt.Sprintf("%s/Identity?filter=primaryAccountEmail:%s&filter=upn:contains:guest&field=upn&field=primaryAccountEmail&field=displayName&field=uid&field=gid&field=type", + m.conf.APIBaseURL, url.QueryEscape(mail)) + return m.getUser(ctx, url) +} + func (m *manager) getInternalUserID(ctx context.Context, uid *userpb.UserId) (string, error) { internalID, err := m.fetchCachedInternalID(uid) if err != nil { - userData, err := m.getUserByParam(ctx, "upn", uid.OpaqueId) + var ( + userData map[string]interface{} + err error + ) + if uid.Type == userpb.UserType_USER_TYPE_LIGHTWEIGHT { + // Lightweight accounts need to be fetched by email + userData, err = m.getLightweightUser(ctx, strings.TrimPrefix(uid.OpaqueId, "guest:")) + } else { + userData, err = m.getUserByParam(ctx, "upn", uid.OpaqueId) + } if err != nil { return "", err } @@ -217,7 +236,16 @@ func (m *manager) parseAndCacheUser(ctx context.Context, userData map[string]int func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId) (*userpb.User, error) { u, err := m.fetchCachedUserDetails(uid) if err != nil { - userData, err := m.getUserByParam(ctx, "upn", uid.OpaqueId) + var ( + userData map[string]interface{} + err error + ) + if uid.Type == userpb.UserType_USER_TYPE_LIGHTWEIGHT { + // Lightweight accounts need to be fetched by email + userData, err = m.getLightweightUser(ctx, strings.TrimPrefix(uid.OpaqueId, "guest:")) + } else { + userData, err = m.getUserByParam(ctx, "upn", uid.OpaqueId) + } if err != nil { return nil, err }