diff --git a/ocis-pkg/oidc/claims.go b/ocis-pkg/oidc/claims.go index b6d4c5016a1..2eadcee2986 100644 --- a/ocis-pkg/oidc/claims.go +++ b/ocis-pkg/oidc/claims.go @@ -1,5 +1,10 @@ package oidc +import ( + "fmt" + "strings" +) + const ( Iss = "iss" Sub = "sub" @@ -12,3 +17,60 @@ const ( OwncloudUUID = "ownclouduuid" OcisRoutingPolicy = "ocis.routing.policy" ) + +// SplitWithEscaping splits s into segments using separator which can be escaped using the escape string +// See https://codereview.stackexchange.com/a/280193 +func SplitWithEscaping(s string, separator string, escapeString string) []string { + a := strings.Split(s, separator) + + for i := len(a) - 2; i >= 0; i-- { + if strings.HasSuffix(a[i], escapeString) { + a[i] = a[i][:len(a[i])-len(escapeString)] + separator + a[i+1] + a = append(a[:i+1], a[i+2:]...) + } + } + return a +} + +// WalkSegments uses the given array of segments to walk the claims and return whatever interface was found +func WalkSegments(segments []string, claims map[string]interface{}) (interface{}, error) { + i := 0 + for ; i < len(segments)-1; i++ { + switch castedClaims := claims[segments[i]].(type) { + case map[string]interface{}: + claims = castedClaims + case map[interface{}]interface{}: + claims = make(map[string]interface{}, len(castedClaims)) + for k, v := range castedClaims { + if s, ok := k.(string); ok { + claims[s] = v + } else { + return nil, fmt.Errorf("could not walk claims path, key '%v' is not a string", k) + } + } + default: + return nil, fmt.Errorf("unsupported type '%v'", castedClaims) + } + } + return claims[segments[i]], nil +} + +// ReadStringClaim returns the string obtained by following the . seperated path in the claims +func ReadStringClaim(path string, claims map[string]interface{}) (string, error) { + // check the simple case first + value, _ := claims[path].(string) + if value != "" { + return value, nil + } + + claim, err := WalkSegments(SplitWithEscaping(path, ".", "\\"), claims) + if err != nil { + return "", err + } + + if value, _ = claim.(string); value != "" { + return value, nil + } + + return value, fmt.Errorf("claim path '%s' not set or empty", path) +} diff --git a/ocis-pkg/oidc/claims_test.go b/ocis-pkg/oidc/claims_test.go new file mode 100644 index 00000000000..5a9c60df093 --- /dev/null +++ b/ocis-pkg/oidc/claims_test.go @@ -0,0 +1,182 @@ +package oidc_test + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/owncloud/ocis/v2/ocis-pkg/oidc" +) + +type splitWithEscapingTest struct { + // Name of the subtest. + name string + + // string to split + s string + + // seperator to use + seperator string + + // escape character to use for escaping + escape string + + expectedParts []string +} + +func (swet splitWithEscapingTest) run(t *testing.T) { + parts := oidc.SplitWithEscaping(swet.s, swet.seperator, swet.escape) + if len(swet.expectedParts) != len(parts) { + t.Errorf("mismatching length") + } + for i, v := range swet.expectedParts { + if parts[i] != v { + t.Errorf("expected part %d to be '%s', got '%s'", i, v, parts[i]) + } + } +} + +func TestSplitWithEscaping(t *testing.T) { + tests := []splitWithEscapingTest{ + { + name: "plain claim name", + s: "roles", + seperator: ".", + escape: "\\", + expectedParts: []string{"roles"}, + }, + { + name: "claim with .", + s: "my.roles", + seperator: ".", + escape: "\\", + expectedParts: []string{"my", "roles"}, + }, + { + name: "claim with escaped .", + s: "my\\.roles", + seperator: ".", + escape: "\\", + expectedParts: []string{"my.roles"}, + }, + { + name: "claim with escaped . left", + s: "my\\.other.roles", + seperator: ".", + escape: "\\", + expectedParts: []string{"my.other", "roles"}, + }, + { + name: "claim with escaped . right", + s: "my.other\\.roles", + seperator: ".", + escape: "\\", + expectedParts: []string{"my", "other.roles"}, + }, + } + for _, test := range tests { + t.Run(test.name, test.run) + } +} + +type walkSegmentsTest struct { + // Name of the subtest. + name string + + // path segments to walk + segments []string + + // seperator to use + claims map[string]interface{} + + expected interface{} + + wantErr bool +} + +func (wst walkSegmentsTest) run(t *testing.T) { + v, err := oidc.WalkSegments(wst.segments, wst.claims) + if err != nil && !wst.wantErr { + t.Errorf("%v", err) + } + if err == nil && wst.wantErr { + t.Errorf("expected error") + } + if !reflect.DeepEqual(v, wst.expected) { + t.Errorf("expected %v got %v", wst.expected, v) + } +} + +func TestWalkSegments(t *testing.T) { + byt := []byte(`{"first":{"second":{"third":["value1","value2"]},"foo":"bar"},"fizz":"buzz"}`) + var dat map[string]interface{} + if err := json.Unmarshal(byt, &dat); err != nil { + t.Errorf("%v", err) + } + + tests := []walkSegmentsTest{ + { + name: "one segment, single value", + segments: []string{"first"}, + claims: map[string]interface{}{ + "first": "value", + }, + expected: "value", + wantErr: false, + }, + { + name: "one segment, array value", + segments: []string{"first"}, + claims: map[string]interface{}{ + "first": []string{"value1", "value2"}, + }, + expected: []string{"value1", "value2"}, + wantErr: false, + }, + { + name: "two segments, single value", + segments: []string{"first", "second"}, + claims: map[string]interface{}{ + "first": map[string]interface{}{ + "second": "value", + }, + }, + expected: "value", + wantErr: false, + }, + { + name: "two segments, array value", + segments: []string{"first", "second"}, + claims: map[string]interface{}{ + "first": map[string]interface{}{ + "second": []string{"value1", "value2"}, + }, + }, + expected: []string{"value1", "value2"}, + wantErr: false, + }, + { + name: "three segments, array value from json", + segments: []string{"first", "second", "third"}, + claims: dat, + expected: []interface{}{"value1", "value2"}, + wantErr: false, + }, + { + name: "three segments, array value with interface key", + segments: []string{"first", "second", "third"}, + claims: map[string]interface{}{ + "first": map[interface{}]interface{}{ + "second": map[interface{}]interface{}{ + "third": []string{"value1", "value2"}, + }, + }, + }, + expected: []string{"value1", "value2"}, + wantErr: false, + }, + } + for _, test := range tests { + t.Run(test.name, test.run) + } +} diff --git a/services/proxy/pkg/middleware/account_resolver.go b/services/proxy/pkg/middleware/account_resolver.go index 0a9079091ce..79fd1ff8b4e 100644 --- a/services/proxy/pkg/middleware/account_resolver.go +++ b/services/proxy/pkg/middleware/account_resolver.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/http" - "strings" "github.com/owncloud/ocis/v2/services/proxy/pkg/user/backend" "github.com/owncloud/ocis/v2/services/proxy/pkg/userroles" @@ -43,19 +42,6 @@ type accountResolver struct { userCS3Claim string } -// from https://codereview.stackexchange.com/a/280193 -func splitWithEscaping(s string, separator string, escapeString string) []string { - a := strings.Split(s, separator) - - for i := len(a) - 2; i >= 0; i-- { - if strings.HasSuffix(a[i], escapeString) { - a[i] = a[i][:len(a[i])-len(escapeString)] + separator + a[i+1] - a = append(a[:i+1], a[i+2:]...) - } - } - return a -} - func readUserIDClaim(path string, claims map[string]interface{}) (string, error) { // happy path value, _ := claims[path].(string) @@ -64,7 +50,7 @@ func readUserIDClaim(path string, claims map[string]interface{}) (string, error) } // try splitting path at . - segments := splitWithEscaping(path, ".", "\\") + segments := oidc.SplitWithEscaping(path, ".", "\\") subclaims := claims lastSegment := len(segments) - 1 for i := range segments { diff --git a/services/proxy/pkg/userroles/oidcroles.go b/services/proxy/pkg/userroles/oidcroles.go index 3a94b2067b3..edc1ad327f8 100644 --- a/services/proxy/pkg/userroles/oidcroles.go +++ b/services/proxy/pkg/userroles/oidcroles.go @@ -9,6 +9,7 @@ import ( cs3 "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1" "github.com/cs3org/reva/v2/pkg/utils" "github.com/owncloud/ocis/v2/ocis-pkg/middleware" + "github.com/owncloud/ocis/v2/ocis-pkg/oidc" settingssvc "github.com/owncloud/ocis/v2/protogen/gen/ocis/services/settings/v0" "go-micro.dev/v4/metadata" ) @@ -29,6 +30,45 @@ func NewOIDCRoleAssigner(opts ...Option) UserRoleAssigner { } } +func extractRoles(rolesClaim string, claims map[string]interface{}) (map[string]struct{}, error) { + + claimRoles := map[string]struct{}{} + // happy path + value, _ := claims[rolesClaim].(string) + if value != "" { + claimRoles[value] = struct{}{} + return claimRoles, nil + } + + claim, err := oidc.WalkSegments(oidc.SplitWithEscaping(rolesClaim, ".", "\\"), claims) + if err != nil { + return nil, err + } + + switch v := claim.(type) { + case []string: + for _, cr := range v { + claimRoles[cr] = struct{}{} + } + case []interface{}: + for _, cri := range v { + cr, ok := cri.(string) + if !ok { + err := errors.New("invalid role in claims") + return nil, err + } + + claimRoles[cr] = struct{}{} + } + case string: + claimRoles[v] = struct{}{} + default: + return nil, errors.New("no roles in user claims") + } + + return claimRoles, nil +} + // UpdateUserRoleAssignment assigns the role "User" to the supplied user. Unless the user // already has a different role assigned. func (ra oidcRoleAssigner) UpdateUserRoleAssignment(ctx context.Context, user *cs3.User, claims map[string]interface{}) (*cs3.User, error) { @@ -39,23 +79,10 @@ func (ra oidcRoleAssigner) UpdateUserRoleAssignment(ctx context.Context, user *c return nil, err } - claimRolesRaw, ok := claims[ra.rolesClaim].([]interface{}) - if !ok { - logger.Error().Str("rolesClaim", ra.rolesClaim).Msg("No roles in user claims") - return nil, errors.New("no roles in user claims") - } - - logger.Debug().Str("rolesClaim", ra.rolesClaim).Interface("rolesInClaim", claims[ra.rolesClaim]).Msg("got roles in claim") - claimRoles := map[string]struct{}{} - for _, cri := range claimRolesRaw { - cr, ok := cri.(string) - if !ok { - err := errors.New("invalid role in claims") - logger.Error().Err(err).Interface("claimValue", cri).Msg("Is not a valid string.") - return nil, err - } - - claimRoles[cr] = struct{}{} + claimRoles, err := extractRoles(ra.rolesClaim, claims) + if err != nil { + logger.Error().Err(err).Msg("Error mapping role names to role ids") + return nil, err } if len(claimRoles) == 0 { diff --git a/services/proxy/pkg/userroles/oidcroles_test.go b/services/proxy/pkg/userroles/oidcroles_test.go new file mode 100644 index 00000000000..879b4edb2d7 --- /dev/null +++ b/services/proxy/pkg/userroles/oidcroles_test.go @@ -0,0 +1,120 @@ +package userroles + +import ( + "encoding/json" + "testing" +) + +func TestExtractRolesArray(t *testing.T) { + byt := []byte(`{"roles":["a","b"]}`) + + claims := map[string]interface{}{} + err := json.Unmarshal(byt, &claims) + if err != nil { + t.Fatal(err) + } + + roles, err := extractRoles("roles", claims) + if err != nil { + t.Fatal(err) + } + if _, ok := roles["a"]; !ok { + t.Fatal("must contain 'a'") + } + if _, ok := roles["b"]; !ok { + t.Fatal("must contain 'b'") + } +} + +func TestExtractRolesString(t *testing.T) { + byt := []byte(`{"roles":"a"}`) + + claims := map[string]interface{}{} + err := json.Unmarshal(byt, &claims) + if err != nil { + t.Fatal(err) + } + + roles, err := extractRoles("roles", claims) + if err != nil { + t.Fatal(err) + } + if _, ok := roles["a"]; !ok { + t.Fatal("must contain 'a'") + } +} + +func TestExtractRolesPathArray(t *testing.T) { + byt := []byte(`{"sub":{"roles":["a","b"]}}`) + + claims := map[string]interface{}{} + err := json.Unmarshal(byt, &claims) + if err != nil { + t.Fatal(err) + } + + roles, err := extractRoles("sub.roles", claims) + if err != nil { + t.Fatal(err) + } + if _, ok := roles["a"]; !ok { + t.Fatal("must contain 'a'") + } + if _, ok := roles["b"]; !ok { + t.Fatal("must contain 'b'") + } +} + +func TestExtractRolesPathString(t *testing.T) { + byt := []byte(`{"sub":{"roles":"a"}}`) + + claims := map[string]interface{}{} + err := json.Unmarshal(byt, &claims) + if err != nil { + t.Fatal(err) + } + + roles, err := extractRoles("sub.roles", claims) + if err != nil { + t.Fatal(err) + } + if _, ok := roles["a"]; !ok { + t.Fatal("must contain 'a'") + } +} + +func TestExtractEscapedRolesPathString(t *testing.T) { + byt := []byte(`{"sub.roles":"a"}`) + + claims := map[string]interface{}{} + err := json.Unmarshal(byt, &claims) + if err != nil { + t.Fatal(err) + } + + roles, err := extractRoles("sub\\.roles", claims) + if err != nil { + t.Fatal(err) + } + if _, ok := roles["a"]; !ok { + t.Fatal("must contain 'a'") + } +} + +func TestNoRoles(t *testing.T) { + byt := []byte(`{"sub":{"foo":"a"}}`) + + claims := map[string]interface{}{} + err := json.Unmarshal(byt, &claims) + if err != nil { + t.Fatal(err) + } + + roles, err := extractRoles("sub.roles", claims) + if err == nil { + t.Fatal("must not find a role") + } + if len(roles) != 0 { + t.Fatal("length of roles mut be 0") + } +}