Skip to content

Commit

Permalink
Add logic for GetUserByClaim
Browse files Browse the repository at this point in the history
  • Loading branch information
ishank011 committed Jul 28, 2020
1 parent eceb762 commit 9031f6e
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 71 deletions.
16 changes: 16 additions & 0 deletions internal/grpc/services/gateway/userprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@ func (s *svc) GetUser(ctx context.Context, req *user.GetUserRequest) (*user.GetU
return res, nil
}

func (s *svc) GetUserByClaim(ctx context.Context, req *user.GetUserByClaimRequest) (*user.GetUserByClaimResponse, error) {
c, err := pool.GetUserProviderServiceClient(s.c.UserProviderEndpoint)
if err != nil {
return &user.GetUserGetUserByClaimResponse{
Status: status.NewInternal(ctx, err, "error getting auth client"),
}, nil
}

res, err := c.GetUserByClaim(ctx, req)
if err != nil {
return nil, errors.Wrap(err, "gateway: error calling GetUserByClaim")
}

return res, nil
}

func (s *svc) FindUsers(ctx context.Context, req *user.FindUsersRequest) (*user.FindUsersResponse, error) {
c, err := pool.GetUserProviderServiceClient(s.c.UserProviderEndpoint)
if err != nil {
Expand Down
63 changes: 26 additions & 37 deletions internal/grpc/services/userprovider/userprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"fmt"

userpb "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
types "github.com/cs3org/go-cs3apis/cs3/types/v1beta1"
"github.com/cs3org/reva/pkg/rgrpc"
"github.com/cs3org/reva/pkg/rgrpc/status"
"github.com/cs3org/reva/pkg/user"
Expand Down Expand Up @@ -100,36 +99,37 @@ func (s *service) Register(ss *grpc.Server) {
}

func (s *service) GetUser(ctx context.Context, req *userpb.GetUserRequest) (*userpb.GetUserResponse, error) {

// Check if we need to retrieve user from UID or not
uid, err := extractUID(req.Opaque)

var usr *userpb.User
if err == nil {
usr, err = s.usermgr.GetUserByUID(ctx, uid)
if err != nil {
// TODO(labkode): check for not found.
err = errors.Wrap(err, "userprovidersvc: error getting user")
res := &userpb.GetUserResponse{
Status: status.NewInternal(ctx, err, "error authenticating user"),
}
return res, nil
}
} else {
usr, err = s.usermgr.GetUser(ctx, req.UserId)
if err != nil {
// TODO(labkode): check for not found.
err = errors.Wrap(err, "userprovidersvc: error getting user")
res := &userpb.GetUserResponse{
Status: status.NewInternal(ctx, err, "error authenticating user"),
}
return res, nil
user, err := s.usermgr.GetUser(ctx, req.UserId)
if err != nil {
// TODO(labkode): check for not found.
err = errors.Wrap(err, "userprovidersvc: error getting user")
res := &userpb.GetUserResponse{
Status: status.NewInternal(ctx, err, "error getting user"),
}
return res, nil
}

res := &userpb.GetUserResponse{
Status: status.NewOK(ctx),
User: usr,
User: user,
}
return res, nil
}

func (s *service) GetUserByClaim(ctx context.Context, req *userpb.GetUserByClaimRequest) (*userpb.GetUserByClaimResponse, error) {
user, err := s.usermgr.GetUserByClaim(ctx, req.Claim, req.Value)
if err != nil {
// TODO(labkode): check for not found.
err = errors.Wrap(err, "userprovidersvc: error getting user by claim")
res := &userpb.GetUserByClaimResponse{
Status: status.NewInternal(ctx, err, "error getting user by claim"),
}
return res, nil
}

res := &userpb.GetUserByClaim{
Status: status.NewOK(ctx),
User: user,
}
return res, nil
}
Expand Down Expand Up @@ -185,14 +185,3 @@ func (s *service) IsInGroup(ctx context.Context, req *userpb.IsInGroupRequest) (

return res, nil
}

func extractUID(opaqueObj *types.Opaque) (string, error) {
if opaqueObj != nil && opaqueObj.Map != nil {
if uidObj, ok := opaqueObj.Map["uid"]; ok {
if uidObj.Decoder == "plain" {
return string(uidObj.Value), nil
}
}
}
return "", errors.New("could not retrieve UID from opaque object")
}
35 changes: 21 additions & 14 deletions pkg/user/manager/demo/demo.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,31 @@ func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId) (*userpb.User
return nil, errtypes.NotFound(uid.OpaqueId)
}

func (m *manager) GetUserByUID(ctx context.Context, uid string) (*userpb.User, error) {
func (m *manager) GetUserByClaim(ctx context.Context, claim, value string) (*userpb.User, error) {
for _, u := range m.catalog {
if userUID, err := extractUID(u); err == nil && uid == userUID {
if userClaim, err := extractClaim(u, claim); err == nil && value == userClaim {
return u, nil
}
}
return nil, errtypes.NotFound(uid)
return nil, errtypes.NotFound(value)
}

func extractClaim(u *userpb.User, claim string) (string, error) {
switch claim {
case "mail":
return u.Mail, nil
case "username":
return u.Username, nil
case "uid":
if u.Opaque != nil && u.Opaque.Map != nil {
if uidObj, ok := u.Opaque.Map["uid"]; ok {
if uidObj.Decoder == "plain" {
return string(uidObj.Value), nil
}
}
}
}
return "", errors.New("demo: invalid field")
}

// TODO(jfd) search Opaque? compare sub?
Expand Down Expand Up @@ -97,17 +115,6 @@ func (m *manager) IsInGroup(ctx context.Context, uid *userpb.UserId, group strin
return false, nil
}

func extractUID(u *userpb.User) (string, error) {
if u.Opaque != nil && u.Opaque.Map != nil {
if uidObj, ok := u.Opaque.Map["uid"]; ok {
if uidObj.Decoder == "plain" {
return string(uidObj.Value), nil
}
}
}
return "", errors.New("demo: could not retrieve UID from user")
}

func getUsers() map[string]*userpb.User {
return map[string]*userpb.User{
"4c510ada-c86b-4815-8820-42cdf82c3d51": &userpb.User{
Expand Down
14 changes: 10 additions & 4 deletions pkg/user/manager/demo/demo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,25 @@ func TestUserManager(t *testing.T) {
t.Fatalf("user not found error differs: expected='%v' got='%v'", expectedErr, err)
}

// positive test GetUserByUID
resUserByUID, _ := manager.GetUserByUID(ctx, "123")
// positive test GetUserByClaim by uid
resUserByUID, _ := manager.GetUserByClaim(ctx, "uid", "123")
if !reflect.DeepEqual(resUserByUID, userEinstein) {
t.Fatalf("user differs: expected=%v got=%v", userEinstein, resUserByUID)
}

// negative test GetUserByUID
// negative test GetUserByClaim by uid
expectedErr = errtypes.NotFound("789")
_, err = manager.GetUserByUID(ctx, "789")
_, err = manager.GetUserByClaim(ctx, "uid", "789")
if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("user not found error differs: expected='%v' got='%v'", expectedErr, err)
}

// positive test GetUserByClaim by mail
resUserByEmail, _ := manager.GetUserByClaim(ctx, "mail", "einstein@example.org")
if !reflect.DeepEqual(resUserByEmail, userEinstein) {
t.Fatalf("user differs: expected=%v got=%v", userEinstein, resUserByEmail)
}

// test FindUsers
resUser, _ := manager.FindUsers(ctx, "einstein")
if !reflect.DeepEqual(resUser, []*userpb.User{userEinstein}) {
Expand Down
27 changes: 25 additions & 2 deletions pkg/user/manager/json/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,31 @@ func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId) (*userpb.User
return nil, errtypes.NotFound(uid.OpaqueId)
}

func (m *manager) GetUserByUID(ctx context.Context, uid string) (*userpb.User, error) {
return nil, errtypes.NotSupported("json: looking up user by UID not supported")
func (m *manager) GetUserByClaim(ctx context.Context, claim, value string) (*userpb.User, error) {
for _, u := range m.users {
if userClaim, err := extractClaim(u, claim); err == nil && value == userClaim {
return u, nil
}
}
return nil, errtypes.NotFound(value)
}

func extractClaim(u *userpb.User, claim string) (string, error) {
switch claim {
case "mail":
return u.Mail, nil
case "username":
return u.Username, nil
case "uid":
if u.Opaque != nil && u.Opaque.Map != nil {
if uidObj, ok := u.Opaque.Map["uid"]; ok {
if uidObj.Decoder == "plain" {
return string(uidObj.Value), nil
}
}
}
}
return "", errors.New("json: invalid field")
}

// TODO(jfd) search Opaque? compare sub?
Expand Down
28 changes: 24 additions & 4 deletions pkg/user/manager/json/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,19 @@ func TestUserManager(t *testing.T) {
manager, _ := New(input)

// setup test data
userEinstein := &userpb.UserId{Idp: "localhost", OpaqueId: "einstein"}
uidEinstein := &userpb.UserId{Idp: "localhost", OpaqueId: "einstein"}
userEinstein := &userpb.User{
Id: uidEinstein,
Username: "einstein",
Groups: []string{"sailing-lovers", "violin-haters", "physics-lovers"},
Mail: "einstein@example.org",
DisplayName: "Albert Einstein",
}
userFake := &userpb.UserId{Idp: "localhost", OpaqueId: "fakeUser"}
groupsEinstein := []string{"sailing-lovers", "violin-haters", "physics-lovers"}

// positive test GetUserGroups
resGroups, _ := manager.GetUserGroups(ctx, userEinstein)
resGroups, _ := manager.GetUserGroups(ctx, uidEinstein)
if !reflect.DeepEqual(resGroups, groupsEinstein) {
t.Fatalf("groups differ: expected=%v got=%v", resGroups, groupsEinstein)
}
Expand All @@ -106,6 +113,19 @@ func TestUserManager(t *testing.T) {
t.Fatalf("user not found error differ: expected='%v' got='%v'", expectedErr, err)
}

// positive test GetUserByClaim by mail
resUserByEmail, _ := manager.GetUserByClaim(ctx, "mail", "einstein@example.org")
if !reflect.DeepEqual(resUserByEmail, userEinstein) {
t.Fatalf("user differs: expected=%v got=%v", userEinstein, resUserByEmail)
}

// negative test GetUserByClaim by mail
expectedErr = errtypes.NotFound("abc@example.com")
_, err = manager.GetUserByClaim(ctx, "mail", "abc@example.com")
if !reflect.DeepEqual(err, expectedErr) {
t.Fatalf("user not found error differs: expected='%v' got='%v'", expectedErr, err)
}

// test FindUsers
resUser, _ := manager.FindUsers(ctx, "stein")
if len(resUser) != 1 {
Expand All @@ -116,13 +136,13 @@ func TestUserManager(t *testing.T) {
}

// positive test IsInGroup
resInGroup, _ := manager.IsInGroup(ctx, userEinstein, "physics-lovers")
resInGroup, _ := manager.IsInGroup(ctx, uidEinstein, "physics-lovers")
if !resInGroup {
t.Fatalf("user not in group: expected=%v got=%v", true, false)
}

// negative test IsInGroup with wrong group
resInGroup, _ = manager.IsInGroup(ctx, userEinstein, "notARealGroup")
resInGroup, _ = manager.IsInGroup(ctx, uidEinstein, "notARealGroup")
if resInGroup {
t.Fatalf("user not in group: expected=%v got=%v", true, false)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/user/manager/ldap/ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId) (*userpb.User
return u, nil
}

func (m *manager) GetUserByUID(ctx context.Context, uid string) (*userpb.User, error) {
return nil, errtypes.NotSupported("ldap: looking up user by UID not supported")
func (m *manager) GetUserByClaim(ctx context.Context, field, claim string) (*userpb.User, error) {
return nil, errtypes.NotSupported("ldap: looking up user by specific field not supported")
}

func (m *manager) FindUsers(ctx context.Context, query string) ([]*userpb.User, error) {
Expand Down
14 changes: 10 additions & 4 deletions pkg/user/manager/rest/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ const (
userDetailsPrefix = "user:"
userGroupsPrefix = "groups:"
userInternalIDPrefix = "internal:"
userUIDPrefix = "uid:"
)

func initRedisPool(address, username, password string) *redis.Pool {
Expand Down Expand Up @@ -136,14 +135,21 @@ func (m *manager) cacheUserDetails(u *userpb.User) error {
if err != nil {
return err
}
if err = m.setVal(userUIDPrefix+uid, u.Id.OpaqueId, -1); err != nil {

if err = m.setVal("uid:"+uid, u.Id.OpaqueId, -1); err != nil {
return err
}
if err = m.setVal("mail:"+u.Mail, u.Id.OpaqueId, -1); err != nil {
return err
}
if err = m.setVal("username:"+u.Username, u.Id.OpaqueId, -1); err != nil {
return err
}
return nil
}

func (m *manager) fetchCachedUID(uid string) (string, error) {
return m.getVal(userUIDPrefix + uid)
func (m *manager) fetchCachedParam(field, claim string) (string, error) {
return m.getVal(field + ":" + claim)
}

func (m *manager) fetchCachedUserGroups(uid *userpb.UserId) ([]string, error) {
Expand Down
17 changes: 14 additions & 3 deletions pkg/user/manager/rest/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,13 +320,24 @@ func (m *manager) GetUser(ctx context.Context, uid *userpb.UserId) (*userpb.User
return u, nil
}

func (m *manager) GetUserByUID(ctx context.Context, uid string) (*userpb.User, error) {
opaqueID, err := m.fetchCachedUID(uid)
func (m *manager) GetUserByClaim(ctx context.Context, claim, value string) (*userpb.User, error) {
opaqueID, err := m.fetchCachedParam(claim, value)
if err == nil {
return m.GetUser(ctx, &userpb.UserId{OpaqueId: opaqueID})
}

userData, err := m.getUserByParam(ctx, "uid", uid)
switch claim {
case "mail":
claim = "primaryAccountEmail"
case "uid":
claim = "uid"
case "username":
claim = "upn"
default:
return nil, errors.New("rest: invalid field")
}

userData, err := m.getUserByParam(ctx, claim, value)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/user/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func ContextSetUserID(ctx context.Context, id *userpb.UserId) context.Context {
// Manager is the interface to implement to manipulate users.
type Manager interface {
GetUser(ctx context.Context, uid *userpb.UserId) (*userpb.User, error)
GetUserByUID(ctx context.Context, uid string) (*userpb.User, error)
GetUserByClaim(ctx context.Context, claim, value string) (*userpb.User, error)
GetUserGroups(ctx context.Context, uid *userpb.UserId) ([]string, error)
IsInGroup(ctx context.Context, uid *userpb.UserId, group string) (bool, error)
FindUsers(ctx context.Context, query string) ([]*userpb.User, error)
Expand Down

0 comments on commit 9031f6e

Please sign in to comment.