Skip to content

Commit

Permalink
InMemoryChannel: reject request for wrong audience (#7449)
Browse files Browse the repository at this point in the history
* Implement the auth check in the event receiver

* Reject request for wrong audience in IMC event_receiver (#1)

* Running the rekt test

Co-authored-by: Christoph Stäbler <cstabler@redhat.com>

* Update pkg/channel/fanout/fanout_event_handler.go

Co-authored-by: Calum Murray <cmurray@redhat.com>

* Update the variable naming

Signed-off-by: Leo Li <leoli@redhat.com>

* Refactor the jwt event auth header verification

Signed-off-by: Leo Li <leoli@redhat.com>

* Revert the refactoring

Signed-off-by: Leo Li <leoli@redhat.com>

* Fix the linter issue

Signed-off-by: Leo Li <leoli@redhat.com>

---------

Signed-off-by: Leo Li <leoli@redhat.com>
Co-authored-by: Christoph Stäbler <cstabler@redhat.com>
Co-authored-by: Calum Murray <cmurray@redhat.com>
  • Loading branch information
3 people authored Dec 5, 2023
1 parent aa0d7bb commit 83125a9
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 17 deletions.
51 changes: 51 additions & 0 deletions pkg/channel/event_receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ import (
nethttp "net/http"
"time"

"knative.dev/eventing/pkg/apis/feature"

"knative.dev/eventing/pkg/auth"

"github.com/cloudevents/sdk-go/v2/event"
"github.com/cloudevents/sdk-go/v2/protocol/http"
"go.uber.org/zap"
Expand Down Expand Up @@ -65,6 +69,9 @@ type EventReceiver struct {
hostToChannelFunc ResolveChannelFromHostFunc
pathToChannelFunc ResolveChannelFromPathFunc
reporter StatsReporter
tokenVerifier *auth.OIDCTokenVerifier
audience string
withContext func(context.Context) context.Context
}

// EventReceiverFunc is the function to be called for handling the event.
Expand Down Expand Up @@ -100,6 +107,21 @@ func ResolveChannelFromPath(PathToChannelFunc ResolveChannelFromPathFunc) EventR
}
}

func OIDCTokenVerification(tokenVerifier *auth.OIDCTokenVerifier, audience string) EventReceiverOptions {
return func(r *EventReceiver) error {
r.tokenVerifier = tokenVerifier
r.audience = audience
return nil
}
}

func ReceiverWithContextFunc(fn func(context.Context) context.Context) EventReceiverOptions {
return func(r *EventReceiver) error {
r.withContext = fn
return nil
}
}

// NewEventReceiver creates an event receiver passing new events to the
// receiverFunc.
func NewEventReceiver(receiverFunc EventReceiverFunc, logger *zap.Logger, reporter StatsReporter, opts ...EventReceiverOptions) (*EventReceiver, error) {
Expand Down Expand Up @@ -153,6 +175,12 @@ func (r *EventReceiver) Start(ctx context.Context) error {
}

func (r *EventReceiver) ServeHTTP(response nethttp.ResponseWriter, request *nethttp.Request) {
ctx := request.Context()

if r.withContext != nil {
ctx = r.withContext(ctx)
}

response.Header().Set("Allow", "POST, OPTIONS")
if request.Method == nethttp.MethodOptions {
response.Header().Set("WebHook-Allowed-Origin", "*") // Accept from any Origin:
Expand Down Expand Up @@ -218,6 +246,29 @@ func (r *EventReceiver) ServeHTTP(response nethttp.ResponseWriter, request *neth
return
}

/// Here we do the OIDC audience verification
features := feature.FromContext(ctx)
if features.IsOIDCAuthentication() {
r.logger.Debug("OIDC authentication is enabled")

token := auth.GetJWTFromHeader(request.Header)
if token == "" {
r.logger.Warn(fmt.Sprintf("No JWT in %s header provided while feature %s is enabled", auth.AuthHeaderKey, feature.OIDCAuthentication))
response.WriteHeader(nethttp.StatusUnauthorized)
return
}

if _, err := r.tokenVerifier.VerifyJWT(ctx, token, r.audience); err != nil {
r.logger.Warn("no valid JWT provided", zap.Error(err))
response.WriteHeader(nethttp.StatusUnauthorized)
return
}

r.logger.Debug("Request contained a valid JWT. Continuing...")
} else {
r.logger.Debug("OIDC authentication is disabled")
}

err = r.receiverFunc(request.Context(), channel, *event, utils.PassThroughHeaders(request.Header))
if err != nil {
if _, ok := err.(*UnknownChannelError); ok {
Expand Down
33 changes: 17 additions & 16 deletions pkg/channel/fanout/fanout_event_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ type FanoutEventHandler struct {
// rather than a member variable.
timeout time.Duration

reporter channel.StatsReporter
logger *zap.Logger
eventTypeHandler *eventtype.EventTypeAutoHandler
channelAddressable *duckv1.KReference
channelUID *types.UID
reporter channel.StatsReporter
logger *zap.Logger
eventTypeHandler *eventtype.EventTypeAutoHandler
channelRef *duckv1.KReference
channelUID *types.UID
}

// NewFanoutEventHandler creates a new fanout.EventHandler.
Expand All @@ -101,20 +101,21 @@ func NewFanoutEventHandler(
config Config,
reporter channel.StatsReporter,
eventTypeHandler *eventtype.EventTypeAutoHandler,
channelAddressable *duckv1.KReference,
channelRef *duckv1.KReference,
channelUID *types.UID,
eventDispatcher *kncloudevents.Dispatcher,
receiverOpts ...channel.EventReceiverOptions,

) (*FanoutEventHandler, error) {
handler := &FanoutEventHandler{
logger: logger,
timeout: defaultTimeout,
reporter: reporter,
asyncHandler: config.AsyncHandler,
eventTypeHandler: eventTypeHandler,
channelAddressable: channelAddressable,
channelUID: channelUID,
eventDispatcher: eventDispatcher,
logger: logger,
timeout: defaultTimeout,
reporter: reporter,
asyncHandler: config.AsyncHandler,
eventTypeHandler: eventTypeHandler,
channelRef: channelRef,
channelUID: channelUID,
eventDispatcher: eventDispatcher,
}
handler.subscriptions = make([]Subscription, len(config.Subscriptions))
copy(handler.subscriptions, config.Subscriptions)
Expand Down Expand Up @@ -184,15 +185,15 @@ func (f *FanoutEventHandler) GetSubscriptions(ctx context.Context) []Subscriptio
}

func (f *FanoutEventHandler) autoCreateEventType(ctx context.Context, evnt event.Event) {
if f.channelAddressable == nil {
if f.channelRef == nil {
f.logger.Warn("No addressable for channel")
return
} else {
if f.channelUID == nil {
f.logger.Warn("No channelUID provided, unable to autocreate event type")
return
}
err := f.eventTypeHandler.AutoCreateEventType(ctx, &evnt, f.channelAddressable, *f.channelUID)
err := f.eventTypeHandler.AutoCreateEventType(ctx, &evnt, f.channelRef, *f.channelUID)
if err != nil {
f.logger.Warn("EventTypeCreate failed")
return
Expand Down
1 change: 1 addition & 0 deletions pkg/reconciler/inmemorychannel/dispatcher/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ func NewController(
eventingClient: eventingclient.Get(ctx).EventingV1beta2(),
eventTypeLister: eventtypeinformer.Get(ctx).Lister(),
eventDispatcher: kncloudevents.NewDispatcher(oidcTokenProvider),
tokenVerifier: auth.NewOIDCTokenVerifier(ctx),
}

var globalResync func(obj interface{})
Expand Down
14 changes: 14 additions & 0 deletions pkg/reconciler/inmemorychannel/dispatcher/inmemorychannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
eventingduckv1 "knative.dev/eventing/pkg/apis/duck/v1"
"knative.dev/eventing/pkg/apis/feature"
v1 "knative.dev/eventing/pkg/apis/messaging/v1"
"knative.dev/eventing/pkg/auth"
"knative.dev/eventing/pkg/channel"
"knative.dev/eventing/pkg/channel/fanout"
"knative.dev/eventing/pkg/channel/multichannelfanout"
Expand All @@ -57,6 +58,7 @@ type Reconciler struct {
eventingClient eventingv1beta2.EventingV1beta2Interface
featureStore *feature.Store
eventDispatcher *kncloudevents.Dispatcher
tokenVerifier *auth.OIDCTokenVerifier
}

// Check the interfaces Reconciler should implement
Expand Down Expand Up @@ -111,6 +113,10 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec
UID = &imc.UID
}

wc := func(ctx context.Context) context.Context {
return r.featureStore.ToContext(ctx)
}

// First grab the host based MultiChannelFanoutMessage httpHandler
httpHandler := r.multiChannelEventHandler.GetChannelHandler(config.HostName)
if httpHandler == nil {
Expand All @@ -123,6 +129,8 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec
channelRef,
UID,
r.eventDispatcher,
channel.OIDCTokenVerification(r.tokenVerifier, audience(imc)),
channel.ReceiverWithContextFunc(wc),
)
if err != nil {
logging.FromContext(ctx).Error("Failed to create a new fanout.EventHandler", err)
Expand Down Expand Up @@ -153,6 +161,8 @@ func (r *Reconciler) reconcile(ctx context.Context, imc *v1.InMemoryChannel) rec
UID,
r.eventDispatcher,
channel.ResolveChannelFromPath(channel.ParseChannelFromPath),
channel.OIDCTokenVerification(r.tokenVerifier, audience(imc)),
channel.ReceiverWithContextFunc(wc),
)
if err != nil {
logging.FromContext(ctx).Error("Failed to create a new fanout.EventHandler", err)
Expand Down Expand Up @@ -295,3 +305,7 @@ func toKReference(imc *v1.InMemoryChannel) *duckv1.KReference {
Address: imc.Status.Address.Name,
}
}

func audience(imc *v1.InMemoryChannel) string {
return auth.GetAudience(v1.SchemeGroupVersion.WithKind("InMemoryChannel"), imc.ObjectMeta)
}
2 changes: 1 addition & 1 deletion test/auth/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func TestChannelImplSupportsOIDC(t *testing.T) {
name := feature.MakeRandomK8sName("channelimpl")
env.Prerequisite(ctx, t, channel.ImplGoesReady(name))

env.Test(ctx, t, oidc.AddressableHasAudiencePopulated(channel_impl.GVR(), channel_impl.GVK().Kind, name, env.Namespace()))
env.TestSet(ctx, t, oidc.AddressableOIDCConformance(channel_impl.GVR(), channel_impl.GVK().Kind, name, env.Namespace()))
}

func TestParallelSupportsOIDC(t *testing.T) {
Expand Down

0 comments on commit 83125a9

Please sign in to comment.