diff --git a/internal/grpc/services/gateway/userprovider.go b/internal/grpc/services/gateway/userprovider.go index 219facabf8..676c950f8d 100644 --- a/internal/grpc/services/gateway/userprovider.go +++ b/internal/grpc/services/gateway/userprovider.go @@ -43,6 +43,22 @@ func (s *svc) GetUser(ctx context.Context, req *user.GetUserRequest) (*user.GetU return res, nil } +func (s *svc) GetUserByClaim(ctx context.Context, req *user.GetUserByClaimRequest) (*user.GetUserByClaimResponse, error) { + c, err := pool.GetUserProviderServiceClient(s.c.UserProviderEndpoint) + if err != nil { + return &user.GetUserGetUserByClaimResponse{ + Status: status.NewInternal(ctx, err, "error getting auth client"), + }, nil + } + + res, err := c.GetUserByClaim(ctx, req) + if err != nil { + return nil, errors.Wrap(err, "gateway: error calling GetUserByClaim") + } + + return res, nil +} + func (s *svc) FindUsers(ctx context.Context, req *user.FindUsersRequest) (*user.FindUsersResponse, error) { c, err := pool.GetUserProviderServiceClient(s.c.UserProviderEndpoint) if err != nil { diff --git a/internal/grpc/services/userprovider/userprovider.go b/internal/grpc/services/userprovider/userprovider.go index 120c6b04b8..16055d6057 100644 --- a/internal/grpc/services/userprovider/userprovider.go +++ b/internal/grpc/services/userprovider/userprovider.go @@ -23,7 +23,6 @@ import ( "fmt" userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1" - types "github.com/cs3org/go-cs3apis/cs3/types/v1beta1" "github.com/cs3org/reva/pkg/rgrpc" "github.com/cs3org/reva/pkg/rgrpc/status" "github.com/cs3org/reva/pkg/user" @@ -100,36 +99,37 @@ func (s *service) Register(ss *grpc.Server) { } func (s *service) GetUser(ctx context.Context, req *userpb.GetUserRequest) (*userpb.GetUserResponse, error) { - - // Check if we need to retrieve user from UID or not - uid, err := extractUID(req.Opaque) - - var usr *userpb.User - if err == nil { - usr, err = s.usermgr.GetUserByUID(ctx, uid) - if err != nil { - // TODO(labkode): check for not found. - err = errors.Wrap(err, "userprovidersvc: error getting user") - res := &userpb.GetUserResponse{ - Status: status.NewInternal(ctx, err, "error authenticating user"), - } - return res, nil - } - } else { - usr, err = s.usermgr.GetUser(ctx, req.UserId) - if err != nil { - // TODO(labkode): check for not found. - err = errors.Wrap(err, "userprovidersvc: error getting user") - res := &userpb.GetUserResponse{ - Status: status.NewInternal(ctx, err, "error authenticating user"), - } - return res, nil + user, err := s.usermgr.GetUser(ctx, req.UserId) + if err != nil { + // TODO(labkode): check for not found. + err = errors.Wrap(err, "userprovidersvc: error getting user") + res := &userpb.GetUserResponse{ + Status: status.NewInternal(ctx, err, "error getting user"), } + return res, nil } res := &userpb.GetUserResponse{ Status: status.NewOK(ctx), - User: usr, + User: user, + } + return res, nil +} + +func (s *service) GetUserByClaim(ctx context.Context, req *userpb.GetUserByClaimRequest) (*userpb.GetUserByClaimResponse, error) { + user, err := s.usermgr.GetUserByClaim(ctx, req.Claim, req.Value) + if err != nil { + // TODO(labkode): check for not found. + err = errors.Wrap(err, "userprovidersvc: error getting user by claim") + res := &userpb.GetUserByClaimResponse{ + Status: status.NewInternal(ctx, err, "error getting user by claim"), + } + return res, nil + } + + res := &userpb.GetUserByClaim{ + Status: status.NewOK(ctx), + User: user, } return res, nil } @@ -185,14 +185,3 @@ func (s *service) IsInGroup(ctx context.Context, req *userpb.IsInGroupRequest) ( return res, nil } - -func extractUID(opaqueObj *types.Opaque) (string, error) { - if opaqueObj != nil && opaqueObj.Map != nil { - if uidObj, ok := opaqueObj.Map["uid"]; ok { - if uidObj.Decoder == "plain" { - return string(uidObj.Value), nil - } - } - } - return "", errors.New("could not retrieve UID from opaque object") -} diff --git a/pkg/user/manager/demo/demo.go b/pkg/user/manager/demo/demo.go index 5d1270db5e..de61836dc2 100644 --- a/pkg/user/manager/demo/demo.go +++ b/pkg/user/manager/demo/demo.go @@ -51,13 +51,31 @@ func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId) (*userpb.User return nil, errtypes.NotFound(uid.OpaqueId) } -func (m *manager) GetUserByUID(ctx context.Context, uid string) (*userpb.User, error) { +func (m *manager) GetUserByClaim(ctx context.Context, claim, value string) (*userpb.User, error) { for _, u := range m.catalog { - if userUID, err := extractUID(u); err == nil && uid == userUID { + if userClaim, err := extractClaim(u, claim); err == nil && value == userClaim { return u, nil } } - return nil, errtypes.NotFound(uid) + return nil, errtypes.NotFound(value) +} + +func extractClaim(u *userpb.User, claim string) (string, error) { + switch claim { + case "mail": + return u.Mail, nil + case "username": + return u.Username, nil + case "uid": + if u.Opaque != nil && u.Opaque.Map != nil { + if uidObj, ok := u.Opaque.Map["uid"]; ok { + if uidObj.Decoder == "plain" { + return string(uidObj.Value), nil + } + } + } + } + return "", errors.New("demo: invalid field") } // TODO(jfd) search Opaque? compare sub? @@ -97,17 +115,6 @@ func (m *manager) IsInGroup(ctx context.Context, uid *userpb.UserId, group strin return false, nil } -func extractUID(u *userpb.User) (string, error) { - if u.Opaque != nil && u.Opaque.Map != nil { - if uidObj, ok := u.Opaque.Map["uid"]; ok { - if uidObj.Decoder == "plain" { - return string(uidObj.Value), nil - } - } - } - return "", errors.New("demo: could not retrieve UID from user") -} - func getUsers() map[string]*userpb.User { return map[string]*userpb.User{ "4c510ada-c86b-4815-8820-42cdf82c3d51": &userpb.User{ diff --git a/pkg/user/manager/demo/demo_test.go b/pkg/user/manager/demo/demo_test.go index faafce0391..f94963109c 100644 --- a/pkg/user/manager/demo/demo_test.go +++ b/pkg/user/manager/demo/demo_test.go @@ -65,19 +65,25 @@ func TestUserManager(t *testing.T) { t.Fatalf("user not found error differs: expected='%v' got='%v'", expectedErr, err) } - // positive test GetUserByUID - resUserByUID, _ := manager.GetUserByUID(ctx, "123") + // positive test GetUserByClaim by uid + resUserByUID, _ := manager.GetUserByClaim(ctx, "uid", "123") if !reflect.DeepEqual(resUserByUID, userEinstein) { t.Fatalf("user differs: expected=%v got=%v", userEinstein, resUserByUID) } - // negative test GetUserByUID + // negative test GetUserByClaim by uid expectedErr = errtypes.NotFound("789") - _, err = manager.GetUserByUID(ctx, "789") + _, err = manager.GetUserByClaim(ctx, "uid", "789") if !reflect.DeepEqual(err, expectedErr) { t.Fatalf("user not found error differs: expected='%v' got='%v'", expectedErr, err) } + // positive test GetUserByClaim by mail + resUserByEmail, _ := manager.GetUserByClaim(ctx, "mail", "einstein@example.org") + if !reflect.DeepEqual(resUserByEmail, userEinstein) { + t.Fatalf("user differs: expected=%v got=%v", userEinstein, resUserByEmail) + } + // test FindUsers resUser, _ := manager.FindUsers(ctx, "einstein") if !reflect.DeepEqual(resUser, []*userpb.User{userEinstein}) { diff --git a/pkg/user/manager/json/json.go b/pkg/user/manager/json/json.go index ffe4fe7220..7202082e0a 100644 --- a/pkg/user/manager/json/json.go +++ b/pkg/user/manager/json/json.go @@ -95,8 +95,31 @@ func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId) (*userpb.User return nil, errtypes.NotFound(uid.OpaqueId) } -func (m *manager) GetUserByUID(ctx context.Context, uid string) (*userpb.User, error) { - return nil, errtypes.NotSupported("json: looking up user by UID not supported") +func (m *manager) GetUserByClaim(ctx context.Context, claim, value string) (*userpb.User, error) { + for _, u := range m.users { + if userClaim, err := extractClaim(u, claim); err == nil && value == userClaim { + return u, nil + } + } + return nil, errtypes.NotFound(value) +} + +func extractClaim(u *userpb.User, claim string) (string, error) { + switch claim { + case "mail": + return u.Mail, nil + case "username": + return u.Username, nil + case "uid": + if u.Opaque != nil && u.Opaque.Map != nil { + if uidObj, ok := u.Opaque.Map["uid"]; ok { + if uidObj.Decoder == "plain" { + return string(uidObj.Value), nil + } + } + } + } + return "", errors.New("json: invalid field") } // TODO(jfd) search Opaque? compare sub? diff --git a/pkg/user/manager/json/json_test.go b/pkg/user/manager/json/json_test.go index 0eabc2fb04..52ced0df2a 100644 --- a/pkg/user/manager/json/json_test.go +++ b/pkg/user/manager/json/json_test.go @@ -89,12 +89,19 @@ func TestUserManager(t *testing.T) { manager, _ := New(input) // setup test data - userEinstein := &userpb.UserId{Idp: "localhost", OpaqueId: "einstein"} + uidEinstein := &userpb.UserId{Idp: "localhost", OpaqueId: "einstein"} + userEinstein := &userpb.User{ + Id: uidEinstein, + Username: "einstein", + Groups: []string{"sailing-lovers", "violin-haters", "physics-lovers"}, + Mail: "einstein@example.org", + DisplayName: "Albert Einstein", + } userFake := &userpb.UserId{Idp: "localhost", OpaqueId: "fakeUser"} groupsEinstein := []string{"sailing-lovers", "violin-haters", "physics-lovers"} // positive test GetUserGroups - resGroups, _ := manager.GetUserGroups(ctx, userEinstein) + resGroups, _ := manager.GetUserGroups(ctx, uidEinstein) if !reflect.DeepEqual(resGroups, groupsEinstein) { t.Fatalf("groups differ: expected=%v got=%v", resGroups, groupsEinstein) } @@ -106,6 +113,19 @@ func TestUserManager(t *testing.T) { t.Fatalf("user not found error differ: expected='%v' got='%v'", expectedErr, err) } + // positive test GetUserByClaim by mail + resUserByEmail, _ := manager.GetUserByClaim(ctx, "mail", "einstein@example.org") + if !reflect.DeepEqual(resUserByEmail, userEinstein) { + t.Fatalf("user differs: expected=%v got=%v", userEinstein, resUserByEmail) + } + + // negative test GetUserByClaim by mail + expectedErr = errtypes.NotFound("abc@example.com") + _, err = manager.GetUserByClaim(ctx, "mail", "abc@example.com") + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("user not found error differs: expected='%v' got='%v'", expectedErr, err) + } + // test FindUsers resUser, _ := manager.FindUsers(ctx, "stein") if len(resUser) != 1 { @@ -116,13 +136,13 @@ func TestUserManager(t *testing.T) { } // positive test IsInGroup - resInGroup, _ := manager.IsInGroup(ctx, userEinstein, "physics-lovers") + resInGroup, _ := manager.IsInGroup(ctx, uidEinstein, "physics-lovers") if !resInGroup { t.Fatalf("user not in group: expected=%v got=%v", true, false) } // negative test IsInGroup with wrong group - resInGroup, _ = manager.IsInGroup(ctx, userEinstein, "notARealGroup") + resInGroup, _ = manager.IsInGroup(ctx, uidEinstein, "notARealGroup") if resInGroup { t.Fatalf("user not in group: expected=%v got=%v", true, false) } diff --git a/pkg/user/manager/ldap/ldap.go b/pkg/user/manager/ldap/ldap.go index c14ce8b554..89473c8b27 100644 --- a/pkg/user/manager/ldap/ldap.go +++ b/pkg/user/manager/ldap/ldap.go @@ -179,8 +179,8 @@ func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId) (*userpb.User return u, nil } -func (m *manager) GetUserByUID(ctx context.Context, uid string) (*userpb.User, error) { - return nil, errtypes.NotSupported("ldap: looking up user by UID not supported") +func (m *manager) GetUserByClaim(ctx context.Context, field, claim string) (*userpb.User, error) { + return nil, errtypes.NotSupported("ldap: looking up user by specific field not supported") } func (m *manager) FindUsers(ctx context.Context, query string) ([]*userpb.User, error) { diff --git a/pkg/user/manager/rest/cache.go b/pkg/user/manager/rest/cache.go index 7bd5adc511..5b27064625 100644 --- a/pkg/user/manager/rest/cache.go +++ b/pkg/user/manager/rest/cache.go @@ -31,7 +31,6 @@ const ( userDetailsPrefix = "user:" userGroupsPrefix = "groups:" userInternalIDPrefix = "internal:" - userUIDPrefix = "uid:" ) func initRedisPool(address, username, password string) *redis.Pool { @@ -136,14 +135,21 @@ func (m *manager) cacheUserDetails(u *userpb.User) error { if err != nil { return err } - if err = m.setVal(userUIDPrefix+uid, u.Id.OpaqueId, -1); err != nil { + + if err = m.setVal("uid:"+uid, u.Id.OpaqueId, -1); err != nil { + return err + } + if err = m.setVal("mail:"+u.Mail, u.Id.OpaqueId, -1); err != nil { + return err + } + if err = m.setVal("username:"+u.Username, u.Id.OpaqueId, -1); err != nil { return err } return nil } -func (m *manager) fetchCachedUID(uid string) (string, error) { - return m.getVal(userUIDPrefix + uid) +func (m *manager) fetchCachedParam(field, claim string) (string, error) { + return m.getVal(field + ":" + claim) } func (m *manager) fetchCachedUserGroups(uid *userpb.UserId) ([]string, error) { diff --git a/pkg/user/manager/rest/rest.go b/pkg/user/manager/rest/rest.go index 1c1b8c3290..7b9f92e5b3 100644 --- a/pkg/user/manager/rest/rest.go +++ b/pkg/user/manager/rest/rest.go @@ -320,13 +320,24 @@ func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId) (*userpb.User return u, nil } -func (m *manager) GetUserByUID(ctx context.Context, uid string) (*userpb.User, error) { - opaqueID, err := m.fetchCachedUID(uid) +func (m *manager) GetUserByClaim(ctx context.Context, claim, value string) (*userpb.User, error) { + opaqueID, err := m.fetchCachedParam(claim, value) if err == nil { return m.GetUser(ctx, &userpb.UserId{OpaqueId: opaqueID}) } - userData, err := m.getUserByParam(ctx, "uid", uid) + switch claim { + case "mail": + claim = "primaryAccountEmail" + case "uid": + claim = "uid" + case "username": + claim = "upn" + default: + return nil, errors.New("rest: invalid field") + } + + userData, err := m.getUserByParam(ctx, claim, value) if err != nil { return nil, err } diff --git a/pkg/user/user.go b/pkg/user/user.go index 7f3409f0b3..dd22062f82 100644 --- a/pkg/user/user.go +++ b/pkg/user/user.go @@ -65,7 +65,7 @@ func ContextSetUserID(ctx context.Context, id *userpb.UserId) context.Context { // Manager is the interface to implement to manipulate users. type Manager interface { GetUser(ctx context.Context, uid *userpb.UserId) (*userpb.User, error) - GetUserByUID(ctx context.Context, uid string) (*userpb.User, error) + GetUserByClaim(ctx context.Context, claim, value string) (*userpb.User, error) GetUserGroups(ctx context.Context, uid *userpb.UserId) ([]string, error) IsInGroup(ctx context.Context, uid *userpb.UserId, group string) (bool, error) FindUsers(ctx context.Context, query string) ([]*userpb.User, error)