diff --git a/pkg/channel/event_receiver.go b/pkg/channel/event_receiver.go index 86ff96cf146..f39e8facc32 100644 --- a/pkg/channel/event_receiver.go +++ b/pkg/channel/event_receiver.go @@ -23,6 +23,10 @@ import ( nethttp "net/http" "time" + "knative.dev/eventing/pkg/apis/feature" + + "knative.dev/eventing/pkg/auth" + "github.com/cloudevents/sdk-go/v2/event" "github.com/cloudevents/sdk-go/v2/protocol/http" "go.uber.org/zap" @@ -65,6 +69,9 @@ type EventReceiver struct { hostToChannelFunc ResolveChannelFromHostFunc pathToChannelFunc ResolveChannelFromPathFunc reporter StatsReporter + tokenVerifier *auth.OIDCTokenVerifier + audience string + withContext func(context.Context) context.Context } // EventReceiverFunc is the function to be called for handling the event. @@ -100,6 +107,21 @@ func ResolveChannelFromPath(PathToChannelFunc ResolveChannelFromPathFunc) EventR } } +func OIDCTokenVerification(tokenVerifier *auth.OIDCTokenVerifier, audience string) EventReceiverOptions { + return func(r *EventReceiver) error { + r.tokenVerifier = tokenVerifier + r.audience = audience + return nil + } +} + +func ReceiverWithContextFunc(fn func(context.Context) context.Context) EventReceiverOptions { + return func(r *EventReceiver) error { + r.withContext = fn + return nil + } +} + // NewEventReceiver creates an event receiver passing new events to the // receiverFunc. func NewEventReceiver(receiverFunc EventReceiverFunc, logger *zap.Logger, reporter StatsReporter, opts ...EventReceiverOptions) (*EventReceiver, error) { @@ -153,6 +175,12 @@ func (r *EventReceiver) Start(ctx context.Context) error { } func (r *EventReceiver) ServeHTTP(response nethttp.ResponseWriter, request *nethttp.Request) { + ctx := request.Context() + + if r.withContext != nil { + ctx = r.withContext(ctx) + } + response.Header().Set("Allow", "POST, OPTIONS") if request.Method == nethttp.MethodOptions { response.Header().Set("WebHook-Allowed-Origin", "*") // Accept from any Origin: @@ -218,6 +246,29 @@ func (r *EventReceiver) ServeHTTP(response nethttp.ResponseWriter, request *neth return } + /// Here we do the OIDC audience verification + features := feature.FromContext(ctx) + if features.IsOIDCAuthentication() { + r.logger.Debug("OIDC authentication is enabled") + + token := auth.GetJWTFromHeader(request.Header) + if token == "" { + r.logger.Warn(fmt.Sprintf("No JWT in %s header provided while feature %s is enabled", auth.AuthHeaderKey, feature.OIDCAuthentication)) + response.WriteHeader(nethttp.StatusUnauthorized) + return + } + + if _, err := r.tokenVerifier.VerifyJWT(ctx, token, r.audience); err != nil { + r.logger.Warn("no valid JWT provided", zap.Error(err)) + response.WriteHeader(nethttp.StatusUnauthorized) + return + } + + r.logger.Debug("Request contained a valid JWT. Continuing...") + } else { + r.logger.Debug("OIDC authentication is disabled") + } + err = r.receiverFunc(request.Context(), channel, *event, utils.PassThroughHeaders(request.Header)) if err != nil { if _, ok := err.(*UnknownChannelError); ok { diff --git a/pkg/channel/fanout/fanout_event_handler.go b/pkg/channel/fanout/fanout_event_handler.go index 0596ea51a8e..f59361fe398 100644 --- a/pkg/channel/fanout/fanout_event_handler.go +++ b/pkg/channel/fanout/fanout_event_handler.go @@ -88,11 +88,11 @@ type FanoutEventHandler struct { // rather than a member variable. timeout time.Duration - reporter channel.StatsReporter - logger *zap.Logger - eventTypeHandler *eventtype.EventTypeAutoHandler - channelAddressable *duckv1.KReference - channelUID *types.UID + reporter channel.StatsReporter + logger *zap.Logger + eventTypeHandler *eventtype.EventTypeAutoHandler + channelRef *duckv1.KReference + channelUID *types.UID } // NewFanoutEventHandler creates a new fanout.EventHandler. @@ -101,20 +101,21 @@ func NewFanoutEventHandler( config Config, reporter channel.StatsReporter, eventTypeHandler *eventtype.EventTypeAutoHandler, - channelAddressable *duckv1.KReference, + channelRef *duckv1.KReference, channelUID *types.UID, eventDispatcher *kncloudevents.Dispatcher, receiverOpts ...channel.EventReceiverOptions, + ) (*FanoutEventHandler, error) { handler := &FanoutEventHandler{ - logger: logger, - timeout: defaultTimeout, - reporter: reporter, - asyncHandler: config.AsyncHandler, - eventTypeHandler: eventTypeHandler, - channelAddressable: channelAddressable, - channelUID: channelUID, - eventDispatcher: eventDispatcher, + logger: logger, + timeout: defaultTimeout, + reporter: reporter, + asyncHandler: config.AsyncHandler, + eventTypeHandler: eventTypeHandler, + channelRef: channelRef, + channelUID: channelUID, + eventDispatcher: eventDispatcher, } handler.subscriptions = make([]Subscription, len(config.Subscriptions)) copy(handler.subscriptions, config.Subscriptions) @@ -184,7 +185,7 @@ func (f *FanoutEventHandler) GetSubscriptions(ctx context.Context) []Subscriptio } func (f *FanoutEventHandler) autoCreateEventType(ctx context.Context, evnt event.Event) { - if f.channelAddressable == nil { + if f.channelRef == nil { f.logger.Warn("No addressable for channel") return } else { @@ -192,7 +193,7 @@ func (f *FanoutEventHandler) autoCreateEventType(ctx context.Context, evnt event f.logger.Warn("No channelUID provided, unable to autocreate event type") return } - err := f.eventTypeHandler.AutoCreateEventType(ctx, &evnt, f.channelAddressable, *f.channelUID) + err := f.eventTypeHandler.AutoCreateEventType(ctx, &evnt, f.channelRef, *f.channelUID) if err != nil { f.logger.Warn("EventTypeCreate failed") return diff --git a/pkg/reconciler/inmemorychannel/dispatcher/controller.go b/pkg/reconciler/inmemorychannel/dispatcher/controller.go index e2e7dd8290d..1f6be14630c 100644 --- a/pkg/reconciler/inmemorychannel/dispatcher/controller.go +++ b/pkg/reconciler/inmemorychannel/dispatcher/controller.go @@ -128,6 +128,7 @@ func NewController( eventingClient: eventingclient.Get(ctx).EventingV1beta2(), eventTypeLister: eventtypeinformer.Get(ctx).Lister(), eventDispatcher: kncloudevents.NewDispatcher(oidcTokenProvider), + tokenVerifier: auth.NewOIDCTokenVerifier(ctx), } var globalResync func(obj interface{}) diff --git a/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel.go b/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel.go index 02a5b84400d..65bfbfa2c95 100644 --- a/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel.go +++ b/pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel.go @@ -37,6 +37,7 @@ import ( eventingduckv1 "knative.dev/eventing/pkg/apis/duck/v1" "knative.dev/eventing/pkg/apis/feature" v1 "knative.dev/eventing/pkg/apis/messaging/v1" + "knative.dev/eventing/pkg/auth" "knative.dev/eventing/pkg/channel" "knative.dev/eventing/pkg/channel/fanout" "knative.dev/eventing/pkg/channel/multichannelfanout" @@ -57,6 +58,7 @@ type Reconciler struct { eventingClient eventingv1beta2.EventingV1beta2Interface featureStore *feature.Store eventDispatcher *kncloudevents.Dispatcher + tokenVerifier *auth.OIDCTokenVerifier } // Check the interfaces Reconciler should implement @@ -111,6 +113,10 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec UID = &imc.UID } + wc := func(ctx context.Context) context.Context { + return r.featureStore.ToContext(ctx) + } + // First grab the host based MultiChannelFanoutMessage httpHandler httpHandler := r.multiChannelEventHandler.GetChannelHandler(config.HostName) if httpHandler == nil { @@ -123,6 +129,8 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec channelRef, UID, r.eventDispatcher, + channel.OIDCTokenVerification(r.tokenVerifier, audience(imc)), + channel.ReceiverWithContextFunc(wc), ) if err != nil { logging.FromContext(ctx).Error("Failed to create a new fanout.EventHandler", err) @@ -153,6 +161,8 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec UID, r.eventDispatcher, channel.ResolveChannelFromPath(channel.ParseChannelFromPath), + channel.OIDCTokenVerification(r.tokenVerifier, audience(imc)), + channel.ReceiverWithContextFunc(wc), ) if err != nil { logging.FromContext(ctx).Error("Failed to create a new fanout.EventHandler", err) @@ -295,3 +305,7 @@ func toKReference(imc *v1.InMemoryChannel) *duckv1.KReference { Address: imc.Status.Address.Name, } } + +func audience(imc *v1.InMemoryChannel) string { + return auth.GetAudience(v1.SchemeGroupVersion.WithKind("InMemoryChannel"), imc.ObjectMeta) +} diff --git a/test/auth/oidc_test.go b/test/auth/oidc_test.go index 218db1bdcac..1379a183e26 100644 --- a/test/auth/oidc_test.go +++ b/test/auth/oidc_test.go @@ -88,7 +88,7 @@ func TestChannelImplSupportsOIDC(t *testing.T) { name := feature.MakeRandomK8sName("channelimpl") env.Prerequisite(ctx, t, channel.ImplGoesReady(name)) - env.Test(ctx, t, oidc.AddressableHasAudiencePopulated(channel_impl.GVR(), channel_impl.GVK().Kind, name, env.Namespace())) + env.TestSet(ctx, t, oidc.AddressableOIDCConformance(channel_impl.GVR(), channel_impl.GVK().Kind, name, env.Namespace())) } func TestParallelSupportsOIDC(t *testing.T) {