diff --git a/auth/context.go b/auth/context.go new file mode 100644 index 00000000000..c5cee929ffe --- /dev/null +++ b/auth/context.go @@ -0,0 +1,65 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import "context" + +var ( + rawKey = rawType{} + subjectKey = subjectType{} + groupsKey = groupsType{} +) + +type rawType struct{} +type subjectType struct{} +type groupsType struct{} + +// NewContextFromRaw creates a new context derived from the given context, +// adding the raw authentication string to the result. +func NewContextFromRaw(ctx context.Context, raw string) context.Context { + return context.WithValue(ctx, rawKey, raw) +} + +// NewContextFromSubject creates a new context derived from the given context, +// adding the subject to the result. +func NewContextFromSubject(ctx context.Context, subject string) context.Context { + return context.WithValue(ctx, subjectKey, subject) +} + +// NewContextFromMemberships creates a new context derived from the given context, +// adding the memberships to the result. +func NewContextFromMemberships(ctx context.Context, groups []string) context.Context { + return context.WithValue(ctx, groupsKey, groups) +} + +// RawFromContext returns the raw authentication string used to perform the authentication. +// It's typically the string value for a single HTTP header, such as the "Authentication" header. +// Example: "Basic ZGV2ZWxvcGVyOmN1cmlvdXM=". +func RawFromContext(ctx context.Context) (string, bool) { + value, ok := ctx.Value(rawKey).(string) + return value, ok +} + +// SubjectFromContext returns the subject that was extracted from the raw authentication string. +func SubjectFromContext(ctx context.Context) (string, bool) { + value, ok := ctx.Value(subjectKey).(string) + return value, ok +} + +// GroupsFromContext returns the list of groups the subject belongs. +func GroupsFromContext(ctx context.Context) ([]string, bool) { + value, ok := ctx.Value(groupsKey).([]string) + return value, ok +} diff --git a/config/configauth/mock_serverauth.go b/config/configauth/mock_serverauth.go index cd11c34f74d..f2522557200 100644 --- a/config/configauth/mock_serverauth.go +++ b/config/configauth/mock_serverauth.go @@ -35,9 +35,9 @@ type MockAuthenticator struct { } // Authenticate executes the mock's AuthenticateFunc, if provided, or just returns the given context unchanged. -func (m *MockAuthenticator) Authenticate(ctx context.Context, headers map[string][]string) error { +func (m *MockAuthenticator) Authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) { if m.AuthenticateFunc == nil { - return nil + return context.Background(), nil } return m.AuthenticateFunc(ctx, headers) } diff --git a/config/configauth/mock_serverauth_test.go b/config/configauth/mock_serverauth_test.go index c22852862b0..329a406c907 100644 --- a/config/configauth/mock_serverauth_test.go +++ b/config/configauth/mock_serverauth_test.go @@ -25,17 +25,18 @@ func TestAuthenticateFunc(t *testing.T) { // prepare m := &MockAuthenticator{} called := false - m.AuthenticateFunc = func(c context.Context, m map[string][]string) error { + m.AuthenticateFunc = func(c context.Context, m map[string][]string) (context.Context, error) { called = true - return nil + return context.Background(), nil } // test - err := m.Authenticate(context.Background(), nil) + ctx, err := m.Authenticate(context.Background(), nil) // verify assert.NoError(t, err) assert.True(t, called) + assert.NotNil(t, ctx) } func TestNilOperations(t *testing.T) { @@ -46,8 +47,9 @@ func TestNilOperations(t *testing.T) { origCtx := context.Background() { - err := m.Authenticate(origCtx, nil) + ctx, err := m.Authenticate(origCtx, nil) assert.NoError(t, err) + assert.NotNil(t, ctx) } { diff --git a/config/configauth/serverauth.go b/config/configauth/serverauth.go index c2c90102444..2a6da1e378b 100644 --- a/config/configauth/serverauth.go +++ b/config/configauth/serverauth.go @@ -40,7 +40,11 @@ type ServerAuthenticator interface { // When the authentication fails, an error must be returned and the caller must not retry. This function is typically called from interceptors, // on behalf of receivers, but receivers can still call this directly if the usage of interceptors isn't suitable. // The deadline and cancellation given to this function must be respected, but note that authentication data has to be part of the map, not context. - Authenticate(ctx context.Context, headers map[string][]string) error + // The resulting context should contain the authentication data, such as the principal/username, group membership (if available), and the raw + // authentication data (if possible). This will allow other components in the pipeline to make decisions based on that data, such as routing based + // on tenancy as determined by the group membership, or passing through the authentication data to the next collector/backend. + // The context keys to be used are not defined yet. + Authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) // GRPCUnaryServerInterceptor is a helper method to provide a gRPC-compatible UnaryServerInterceptor, typically calling the authenticator's Authenticate method. // While the context is the typical source of authentication data, the interceptor is free to determine where the auth data should come from. For instance, some @@ -59,7 +63,7 @@ type ServerAuthenticator interface { // AuthenticateFunc defines the signature for the function responsible for performing the authentication based on the given headers map. // See ServerAuthenticator.Authenticate. -type AuthenticateFunc func(ctx context.Context, headers map[string][]string) error +type AuthenticateFunc func(ctx context.Context, headers map[string][]string) (context.Context, error) // GRPCUnaryInterceptorFunc defines the signature for the function intercepting unary gRPC calls, useful for authenticators to use as // types for internal structs, making it easier to mock them in tests. @@ -79,7 +83,8 @@ func DefaultGRPCUnaryServerInterceptor(ctx context.Context, req interface{}, _ * return nil, errMetadataNotFound } - if err := authenticate(ctx, headers); err != nil { + ctx, err := authenticate(ctx, headers) + if err != nil { return nil, err } @@ -95,7 +100,9 @@ func DefaultGRPCStreamServerInterceptor(srv interface{}, stream grpc.ServerStrea return errMetadataNotFound } - if err := authenticate(ctx, headers); err != nil { + // TODO: propagate the context down the stream + _, err := authenticate(ctx, headers) + if err != nil { return err } diff --git a/config/configauth/serverauth_test.go b/config/configauth/serverauth_test.go index d8314973730..4a37c9ba284 100644 --- a/config/configauth/serverauth_test.go +++ b/config/configauth/serverauth_test.go @@ -28,9 +28,9 @@ func TestDefaultUnaryInterceptorAuthSucceeded(t *testing.T) { // prepare handlerCalled := false authCalled := false - authFunc := func(context.Context, map[string][]string) error { + authFunc := func(context.Context, map[string][]string) (context.Context, error) { authCalled = true - return nil + return context.Background(), nil } handler := func(ctx context.Context, req interface{}) (interface{}, error) { handlerCalled = true @@ -52,9 +52,9 @@ func TestDefaultUnaryInterceptorAuthFailure(t *testing.T) { // prepare authCalled := false expectedErr := fmt.Errorf("not authenticated") - authFunc := func(context.Context, map[string][]string) error { + authFunc := func(context.Context, map[string][]string) (context.Context, error) { authCalled = true - return expectedErr + return context.Background(), expectedErr } handler := func(ctx context.Context, req interface{}) (interface{}, error) { assert.FailNow(t, "the handler should not have been called on auth failure!") @@ -73,9 +73,9 @@ func TestDefaultUnaryInterceptorAuthFailure(t *testing.T) { func TestDefaultUnaryInterceptorMissingMetadata(t *testing.T) { // prepare - authFunc := func(context.Context, map[string][]string) error { + authFunc := func(context.Context, map[string][]string) (context.Context, error) { assert.FailNow(t, "the auth func should not have been called!") - return nil + return context.Background(), nil } handler := func(ctx context.Context, req interface{}) (interface{}, error) { assert.FailNow(t, "the handler should not have been called!") @@ -94,9 +94,9 @@ func TestDefaultStreamInterceptorAuthSucceeded(t *testing.T) { // prepare handlerCalled := false authCalled := false - authFunc := func(context.Context, map[string][]string) error { + authFunc := func(context.Context, map[string][]string) (context.Context, error) { authCalled = true - return nil + return context.Background(), nil } handler := func(srv interface{}, stream grpc.ServerStream) error { handlerCalled = true @@ -120,9 +120,9 @@ func TestDefaultStreamInterceptorAuthFailure(t *testing.T) { // prepare authCalled := false expectedErr := fmt.Errorf("not authenticated") - authFunc := func(context.Context, map[string][]string) error { + authFunc := func(context.Context, map[string][]string) (context.Context, error) { authCalled = true - return expectedErr + return context.Background(), expectedErr } handler := func(srv interface{}, stream grpc.ServerStream) error { assert.FailNow(t, "the handler should not have been called on auth failure!") @@ -143,9 +143,9 @@ func TestDefaultStreamInterceptorAuthFailure(t *testing.T) { func TestDefaultStreamInterceptorMissingMetadata(t *testing.T) { // prepare - authFunc := func(context.Context, map[string][]string) error { + authFunc := func(context.Context, map[string][]string) (context.Context, error) { assert.FailNow(t, "the auth func should not have been called!") - return nil + return context.Background(), nil } handler := func(srv interface{}, stream grpc.ServerStream) error { assert.FailNow(t, "the handler should not have been called!") diff --git a/extension/oidcauthextension/extension.go b/extension/oidcauthextension/extension.go index 59d314892d9..bf2918b3ab3 100644 --- a/extension/oidcauthextension/extension.go +++ b/extension/oidcauthextension/extension.go @@ -32,6 +32,7 @@ import ( "go.uber.org/zap" "google.golang.org/grpc" + "go.opentelemetry.io/collector/auth" "go.opentelemetry.io/collector/component" "go.opentelemetry.io/collector/config/configauth" ) @@ -100,21 +101,23 @@ func (e *oidcExtension) Shutdown(context.Context) error { } // Authenticate checks whether the given context contains valid auth data. Successfully authenticated calls will always return a nil error and a context with the auth data. -func (e *oidcExtension) Authenticate(ctx context.Context, headers map[string][]string) error { +func (e *oidcExtension) Authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) { authHeaders := headers[e.cfg.Attribute] if len(authHeaders) == 0 { - return errNotAuthenticated + return ctx, errNotAuthenticated } + raw := authHeaders[0] + // we only use the first header, if multiple values exist - parts := strings.Split(authHeaders[0], " ") + parts := strings.Split(raw, " ") if len(parts) != 2 { - return errInvalidAuthenticationHeaderFormat + return ctx, errInvalidAuthenticationHeaderFormat } idToken, err := e.verifier.Verify(ctx, parts[1]) if err != nil { - return fmt.Errorf("failed to verify token: %w", err) + return ctx, fmt.Errorf("failed to verify token: %w", err) } claims := map[string]interface{}{} @@ -125,20 +128,26 @@ func (e *oidcExtension) Authenticate(ctx context.Context, headers map[string][]s // to read the claims. It could fail if we were using a custom struct. Instead of // swalling the error, it's better to make this future-proof, in case the underlying // code changes - return errFailedToObtainClaimsFromToken + return ctx, errFailedToObtainClaimsFromToken } - _, err = getSubjectFromClaims(claims, e.cfg.UsernameClaim, idToken.Subject) + // we could have set this right after obtaining the raw auth string, but we should only change the + // context if the auth was successful + ctx = auth.NewContextFromRaw(ctx, raw) + + sub, err := getSubjectFromClaims(claims, e.cfg.UsernameClaim, idToken.Subject) if err != nil { - return fmt.Errorf("failed to get subject from claims in the token: %w", err) + return ctx, fmt.Errorf("failed to get subject from claims in the token: %w", err) } + ctx = auth.NewContextFromSubject(ctx, sub) - _, err = getGroupsFromClaims(claims, e.cfg.GroupsClaim) + groups, err := getGroupsFromClaims(claims, e.cfg.GroupsClaim) if err != nil { - return fmt.Errorf("failed to get groups from claims in the token: %w", err) + return ctx, fmt.Errorf("failed to get groups from claims in the token: %w", err) } + ctx = auth.NewContextFromMemberships(ctx, groups) - return nil + return ctx, nil } // GRPCUnaryServerInterceptor is a helper method to provide a gRPC-compatible UnaryInterceptor, typically calling the authenticator's Authenticate method. diff --git a/extension/oidcauthextension/extension_test.go b/extension/oidcauthextension/extension_test.go index da791366663..0ab33f9663f 100644 --- a/extension/oidcauthextension/extension_test.go +++ b/extension/oidcauthextension/extension_test.go @@ -69,10 +69,11 @@ func TestOIDCAuthenticationSucceeded(t *testing.T) { require.NoError(t, err) // test - err = p.Authenticate(context.Background(), map[string][]string{"authorization": {fmt.Sprintf("Bearer %s", token)}}) + ctx, err := p.Authenticate(context.Background(), map[string][]string{"authorization": {fmt.Sprintf("Bearer %s", token)}}) // verify assert.NoError(t, err) + assert.NotNil(t, ctx) // TODO(jpkroehling): assert that the authentication routine set the subject/membership to the resource } @@ -209,10 +210,11 @@ func TestOIDCInvalidAuthHeader(t *testing.T) { require.NoError(t, err) // test - err = p.Authenticate(context.Background(), map[string][]string{"authorization": {"some-value"}}) + ctx, err := p.Authenticate(context.Background(), map[string][]string{"authorization": {"some-value"}}) // verify assert.Equal(t, errInvalidAuthenticationHeaderFormat, err) + assert.NotNil(t, ctx) } func TestOIDCNotAuthenticated(t *testing.T) { @@ -224,10 +226,11 @@ func TestOIDCNotAuthenticated(t *testing.T) { require.NoError(t, err) // test - err = p.Authenticate(context.Background(), make(map[string][]string)) + ctx, err := p.Authenticate(context.Background(), make(map[string][]string)) // verify assert.Equal(t, errNotAuthenticated, err) + assert.NotNil(t, ctx) } func TestProviderNotReacheable(t *testing.T) { @@ -262,10 +265,11 @@ func TestFailedToVerifyToken(t *testing.T) { require.NoError(t, err) // test - err = p.Authenticate(context.Background(), map[string][]string{"authorization": {"Bearer some-token"}}) + ctx, err := p.Authenticate(context.Background(), map[string][]string{"authorization": {"Bearer some-token"}}) // verify assert.Error(t, err) + assert.NotNil(t, ctx) } func TestFailedToGetGroupsClaimFromToken(t *testing.T) { @@ -325,10 +329,11 @@ func TestFailedToGetGroupsClaimFromToken(t *testing.T) { require.NoError(t, err) // test - err = p.Authenticate(context.Background(), map[string][]string{"authorization": {fmt.Sprintf("Bearer %s", token)}}) + ctx, err := p.Authenticate(context.Background(), map[string][]string{"authorization": {fmt.Sprintf("Bearer %s", token)}}) // verify assert.ErrorIs(t, err, tt.expectedError) + assert.NotNil(t, ctx) }) } }