Skip to content

Commit

Permalink
fix: spurious cancelation of async webhooks, better tracing
Browse files Browse the repository at this point in the history
Previously, async webhooks (response.ignore=true) would be canceled
early once the incoming Kratos request was served and it's associated
context released. We now dissociate the cancellation of async hooks
from the normal request processing flow.
  • Loading branch information
alnr committed Dec 19, 2022
1 parent c2adc6b commit 134295e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 34 deletions.
57 changes: 34 additions & 23 deletions selfservice/hook/web_hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import (

"github.com/pkg/errors"
"github.com/tidwall/gjson"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
semconv "go.opentelemetry.io/otel/semconv/v1.11.0"
"go.opentelemetry.io/otel/trace"

"github.com/ory/kratos/ui/node"
Expand All @@ -30,7 +32,6 @@ import (
"github.com/ory/kratos/session"
"github.com/ory/kratos/text"
"github.com/ory/kratos/x"
"github.com/ory/x/otelx"
)

var (
Expand Down Expand Up @@ -254,22 +255,6 @@ func (e *WebHook) ExecuteSettingsPrePersistHook(_ http.ResponseWriter, req *http
}

func (e *WebHook) execute(ctx context.Context, data *templateContext) error {
span := trace.SpanFromContext(ctx)
attrs := map[string]string{
"webhook.http.method": data.RequestMethod,
"webhook.http.url": data.RequestURL,
"webhook.http.headers": fmt.Sprintf("%#v", data.RequestHeaders),
}

if data.Identity != nil {
attrs["webhook.identity.id"] = data.Identity.ID.String()
} else {
attrs["webhook.identity.id"] = ""
}

span.SetAttributes(otelx.StringAttrs(attrs)...)
defer span.End()

builder, err := request.NewBuilder(e.conf, e.deps)
if err != nil {
return err
Expand All @@ -282,34 +267,60 @@ func (e *WebHook) execute(ctx context.Context, data *templateContext) error {
return err
}

errChan := make(chan error, 1)
attrs := semconv.HTTPClientAttributesFromHTTPRequest(req.Request)
if data.Identity != nil {
attrs = append(attrs,
attribute.String("webhook.identity.id", data.Identity.ID.String()),
attribute.String("webhook.identity.nid", data.Identity.NID.String()),
)
}
var (
httpClient = e.deps.HTTPClient(ctx)
async = gjson.GetBytes(e.conf, "response.ignore").Bool()
parseResponse = gjson.GetBytes(e.conf, "can_interrupt").Bool()
tracer = trace.SpanFromContext(ctx).TracerProvider().Tracer("kratos-webhooks")
cancel context.CancelFunc = func() {}
spanOpts = []trace.SpanStartOption{trace.WithAttributes(attrs...)}
errChan = make(chan error, 1)
)
if async {
// dissociate the context from the one passed into this function
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Minute)
spanOpts = append(spanOpts, trace.WithNewRoot())
}
ctx, span := tracer.Start(ctx, "Webhook", spanOpts...)
e.deps.Logger().WithRequest(req.Request).Info("Dispatching webhook")
t0 := time.Now()
go func() {
defer close(errChan)
defer cancel()
defer span.End()

resp, err := e.deps.HTTPClient(ctx).Do(req.WithContext(ctx))
resp, err := httpClient.Do(req.WithContext(ctx))
if err != nil {
span.SetStatus(codes.Error, err.Error())
errChan <- errors.WithStack(err)
return
}
defer resp.Body.Close()
span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(resp.StatusCode)...)

if resp.StatusCode >= http.StatusBadRequest {
if gjson.GetBytes(e.conf, "can_interrupt").Bool() {
span.SetStatus(codes.Error, "HTTP status code >= 400")
if parseResponse {
if err := parseWebhookResponse(resp); err != nil {
span.SetStatus(codes.Error, err.Error())
errChan <- err
}
}
errChan <- fmt.Errorf("web hook failed with status code %v", resp.StatusCode)
span.SetStatus(codes.Error, fmt.Sprintf("web hook failed with status code %v", resp.StatusCode))
errChan <- fmt.Errorf("webhook failed with status code %v", resp.StatusCode)
return
}

errChan <- nil
}()

if gjson.GetBytes(e.conf, "response.ignore").Bool() {
if async {
traceID, spanID := span.SpanContext().TraceID(), span.SpanContext().SpanID()
go func() {
if err := <-errChan; err != nil {
Expand Down
20 changes: 9 additions & 11 deletions selfservice/hook/web_hook_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -842,10 +842,7 @@ func TestDisallowPrivateIPRanges(t *testing.T) {
}

func TestAsyncWebhook(t *testing.T) {
conf, reg := internal.NewFastRegistryWithMocks(t)
_ = conf
// conf.MustSet(ctx, config.ViperKeyClientHTTPNoPrivateIPRanges, true)
// conf.MustSet(ctx, config.ViperKeyClientHTTPPrivateIPExceptionURLs, []string{webhookReceiver.URL})
_, reg := internal.NewFastRegistryWithMocks(t)
logger := logrusx.New("kratos", "test")
logHook := new(test.Hook)
logger.Logger.Hooks.Add(logHook)
Expand All @@ -866,6 +863,7 @@ func TestAsyncWebhook(t *testing.T) {
}
incomingCtx, incomingCancel := context.WithCancel(context.Background())
if deadline, ok := t.Deadline(); ok {
// cancel this context one second before test timeout for clean shutdown
var cleanup context.CancelFunc
incomingCtx, cleanup = context.WithDeadline(incomingCtx, deadline.Add(-time.Second))
defer cleanup()
Expand All @@ -881,7 +879,6 @@ func TestAsyncWebhook(t *testing.T) {
w.Write([]byte("ok"))
}))
t.Cleanup(webhookReceiver.Close)
// defer webhookReceiver.Close()

wh := hook.NewWebHook(&whDeps, json.RawMessage(fmt.Sprintf(`
{
Expand All @@ -902,7 +899,7 @@ func TestAsyncWebhook(t *testing.T) {
}
// at this point, a goroutine is in the middle of the call to our test handler and waiting for a response
incomingCancel() // simulate the incoming Kratos request having finished
testFor := time.After(200 * time.Millisecond)
timeout := time.After(200 * time.Millisecond)
for done := false; !done; {
if last := logHook.LastEntry(); last != nil {
msg, err := last.String()
Expand All @@ -911,24 +908,25 @@ func TestAsyncWebhook(t *testing.T) {
}

select {
case <-testFor:
case <-timeout:
done = true
case <-time.After(50 * time.Millisecond):
// continue loop
}
}
logHook.Reset()
close(blockHandlerOnExit)
testFor = time.After(200 * time.Millisecond)
for done := false; !done; {
timeout = time.After(200 * time.Millisecond)
for {
if last := logHook.LastEntry(); last != nil {
msg, err := last.String()
require.NoError(t, err)
assert.Contains(t, msg, "Webhook request succeeded")
break
}
select {
case <-testFor:
done = true
case <-timeout:
t.Fatal("timed out waiting for successful webhook completion")
case <-time.After(50 * time.Millisecond):
// continue loop
}
Expand Down

0 comments on commit 134295e

Please sign in to comment.