From 6f22ca9bd2521a796a066f40261ed25c14ddadeb Mon Sep 17 00:00:00 2001 From: Mickey Reiss Date: Tue, 8 Sep 2020 17:48:02 -0500 Subject: [PATCH] contrib/twitchtv/twirp: Add explicit reference to the request span --- contrib/twitchtv/twirp/twirp.go | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/contrib/twitchtv/twirp/twirp.go b/contrib/twitchtv/twirp/twirp.go index dae161b840..e4382836c0 100644 --- a/contrib/twitchtv/twirp/twirp.go +++ b/contrib/twitchtv/twirp/twirp.go @@ -21,10 +21,9 @@ import ( "github.com/twitchtv/twirp" ) -type contextKey int - -const ( - twirpErrorKey contextKey = iota +type ( + twirpErrorKey struct{} + twirpSpanKey struct{} ) // HTTPClient is duplicated from twirp's generated service code. @@ -167,15 +166,23 @@ func requestReceivedHook(cfg *config) func(context.Context) (context.Context, er if !math.IsNaN(cfg.analyticsRate) { opts = append(opts, tracer.Tag(ext.EventSampleRate, cfg.analyticsRate)) } - _, ctx = tracer.StartSpanFromContext(ctx, spanNameFromContext(ctx), opts...) + span, ctx := tracer.StartSpanFromContext(ctx, spanNameFromContext(ctx), opts...) + + ctx = context.WithValue(ctx, twirpSpanKey{}, span) return ctx, nil } } func requestRoutedHook(cfg *config) func(context.Context) (context.Context, error) { return func(ctx context.Context) (context.Context, error) { - span, ok := tracer.SpanFromContext(ctx) + maybeSpan := ctx.Value(twirpSpanKey{}) + if maybeSpan == nil { + log.Error("contrib/twitchtv/twirp.requestRoutedHook: found no span in context") + return ctx, nil + } + span, ok := maybeSpan.(tracer.Span) if !ok { + log.Error("contrib/twitchtv/twirp.requestRoutedHook: found invalid span type in context") return ctx, nil } if method, ok := twirp.MethodName(ctx); ok { @@ -194,20 +201,24 @@ func responsePreparedHook(cfg *config) func(context.Context) context.Context { func responseSentHook(cfg *config) func(context.Context) { return func(ctx context.Context) { - span, ok := tracer.SpanFromContext(ctx) + maybeSpan := ctx.Value(twirpSpanKey{}) + if maybeSpan == nil { + return + } + span, ok := maybeSpan.(tracer.Span) if !ok { return } if sc, ok := twirp.StatusCode(ctx); ok { span.SetTag(ext.HTTPCode, sc) } - err, _ := ctx.Value(twirpErrorKey).(twirp.Error) + err, _ := ctx.Value(twirpErrorKey{}).(twirp.Error) span.Finish(tracer.WithError(err)) } } func errorHook(cfg *config) func(context.Context, twirp.Error) context.Context { return func(ctx context.Context, err twirp.Error) context.Context { - return context.WithValue(ctx, twirpErrorKey, err) + return context.WithValue(ctx, twirpErrorKey{}, err) } }