Skip to content

Commit

Permalink
[Tracing] Fix WebRequest bug with existing distributed tracing headers (
Browse files Browse the repository at this point in the history
#4770)

When we pre-emptively add distributed tracing headers in `WebRequest.BeginGetRequestStream`, `WebRequest.BeginGetResponse`, and `WebRequest.GetRequestStream`, add the headers to a cache to signal to the later integrations that we had last updated the headers and we are responsible for the distributed tracing headers in the WebRequest object.

When we check for the pre-emptively added distributed tracing headers in `WebRequest.EndGetResponse`, `WebRequest.GetResponse` and `WebRequest.GetResponseAsync`, first check a cache to know whether we added distributed tracing headers to the WebRequest object. If present, then continue to use the injected trace context. Otherwise, we create a new context.
  • Loading branch information
zacharycmontoya authored Nov 2, 2023
1 parent 00d32ab commit 7124515
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// <copyright file="HeadersInjectedCache.cs" company="Datadog">
// Unless explicitly stated otherwise all files in this repository are licensed under the Apache 2 License.
// This product includes software developed at Datadog (https://www.datadoghq.com/). Copyright 2017 Datadog, Inc.
// </copyright>

#nullable enable

using System.Net;
using System.Runtime.CompilerServices;

namespace Datadog.Trace.ClrProfiler.AutoInstrumentation.Http.WebRequest;

internal static class HeadersInjectedCache
{
private static readonly object InjectedValue = new();
private static readonly ConditionalWeakTable<WebHeaderCollection, object> Cache = new();

public static void SetInjectedHeaders(WebHeaderCollection headers)
{
#if NETCOREAPP3_1_OR_GREATER
Cache.AddOrUpdate(headers, InjectedValue);
#else
Cache.GetValue(headers, _ => InjectedValue);
#endif
}

public static bool TryGetInjectedHeaders(WebHeaderCollection headers)
{
return Cache.TryGetValue(headers, out _);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ internal static CallTargetState OnMethodBegin<TTarget>(TTarget instance, AsyncCa
// The expected sequence of calls is GetRequestStream -> GetResponse. Headers can't be modified after calling GetRequestStream.
// At the same time, we don't want to set an active scope now, because it's possible that GetResponse will never be called.
// Instead, we generate a spancontext and inject it in the headers. GetResponse will fetch them and create an active scope with the right id.
// Additionally, add the request headers to a cache to indicate that distributed tracing headers were
// added by us, not the application
SpanContextPropagator.Instance.Inject(span.Context, request.Headers.Wrap());
HeadersInjectedCache.SetInjectedHeaders(request.Headers);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ internal static CallTargetState OnMethodBegin<TTarget>(TTarget instance, AsyncCa
// Add distributed tracing headers to the HTTP request.
// We don't want to set an active scope now, because it's possible that EndGetResponse will never be called.
// Instead, we generate a spancontext and inject it in the headers. EndGetResponse will fetch them and create an active scope with the right id.
// Additionally, add the request headers to a cache to indicate that distributed tracing headers were
// added by us, not the application
SpanContextPropagator.Instance.Inject(span.Context, request.Headers.Wrap());
HeadersInjectedCache.SetInjectedHeaders(request.Headers);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,14 @@ internal static CallTargetReturn<TReturn> OnMethodEnd<TTarget, TReturn>(TTarget
}

// Check if any headers were injected by a previous call
var existingSpanContext = SpanContextPropagator.Instance.Extract(request.Headers.Wrap());
// Since it is possible for users to manually propagate headers (which we should
// overwrite), check our cache which will be populated with header objects
// that we have injected context into
SpanContext existingSpanContext = null;
if (HeadersInjectedCache.TryGetInjectedHeaders(request.Headers))
{
existingSpanContext = SpanContextPropagator.Instance.Extract(request.Headers.Wrap());
}

// If this operation creates the trace, then we need to re-apply the sampling priority
bool setSamplingPriority = existingSpanContext?.SamplingPriority != null && Tracer.Instance.ActiveScope == null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ internal static CallTargetState OnMethodBegin<TTarget>(TTarget instance)
// The expected sequence of calls is GetRequestStream -> GetResponse. Headers can't be modified after calling GetRequestStream.
// At the same time, we don't want to set an active scope now, because it's possible that GetResponse will never be called.
// Instead, we generate a spancontext and inject it in the headers. GetResponse will fetch them and create an active scope with the right id.
// Additionally, add the request headers to a cache to indicate that distributed tracing headers were
// added by us, not the application
SpanContextPropagator.Instance.Inject(span.Context, request.Headers.Wrap());
HeadersInjectedCache.SetInjectedHeaders(request.Headers);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

using System;
using System.Net;
using System.Runtime.CompilerServices;
using Datadog.Trace.ClrProfiler.CallTarget;
using Datadog.Trace.Configuration;
using Datadog.Trace.ExtensionMethods;
Expand Down Expand Up @@ -42,7 +43,14 @@ public static CallTargetState GetResponse_OnMethodBegin<TTarget>(TTarget instanc
if (instance is HttpWebRequest request && IsTracingEnabled(request))
{
// Check if any headers were injected by a previous call to GetRequestStream
var spanContext = SpanContextPropagator.Instance.Extract(request.Headers.Wrap());
// Since it is possible for users to manually propagate headers (which we should
// overwrite), check our cache which will be populated with header objects
// that we have injected context into
SpanContext spanContext = null;
if (HeadersInjectedCache.TryGetInjectedHeaders(request.Headers))
{
spanContext = SpanContextPropagator.Instance.Extract(request.Headers.Wrap());
}

// If this operation creates the trace, then we need to re-apply the sampling priority
var tracer = Tracer.Instance;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public PeerServiceMappingTests(ITestOutputHelper output)
[Trait("SupportsInstrumentationVerification", "True")]
public void RenamesService()
{
var expectedSpanCount = 82;
var expectedSpanCount = 87;

SetInstrumentationVerification();
const string expectedOperationName = "http.request";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public ServiceMappingTests(ITestOutputHelper output)
[Trait("SupportsInstrumentationVerification", "True")]
public void RenamesService()
{
var expectedSpanCount = 82;
var expectedSpanCount = 87;

SetInstrumentationVerification();
const string expectedOperationName = "http.request";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ public void TracingDisabled_DoesNotSubmitsTraces()
private void RunTest(string metadataSchemaVersion)
{
SetInstrumentationVerification();
var expectedSpanCount = 82;
var expectedAllSpansCount = 130;
var expectedSpanCount = 87;

int httpPort = TcpPortProvider.GetOpenPort();
Output.WriteLine($"Assigning port {httpPort} for the httpPort.");
Expand All @@ -85,8 +86,10 @@ private void RunTest(string metadataSchemaVersion)
using (var agent = EnvironmentHelper.GetMockAgent())
using (ProcessResult processResult = RunSampleAndWaitForExit(agent, arguments: $"Port={httpPort}"))
{
agent.SpanFilters.Add(s => s.Type == SpanTypes.Http);
var spans = agent.WaitForSpans(expectedSpanCount).OrderBy(s => s.Start);
var allSpans = agent.WaitForSpans(expectedAllSpansCount).OrderBy(s => s.Start);
allSpans.Should().OnlyHaveUniqueItems(s => new { s.SpanId, s.TraceId });

var spans = allSpans.Where(s => s.Type == SpanTypes.Http);
spans.Should().HaveCount(expectedSpanCount);
ValidateIntegrationSpans(spans, metadataSchemaVersion, expectedServiceName: clientSpanServiceName, isExternalSpan);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Collections.Specialized;
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Text;
Expand Down Expand Up @@ -332,6 +333,24 @@ public static async Task SendWebRequestRequests(bool tracingDisabled, string url
Console.WriteLine("Received response for request.GetResponse()");
}

using (_sampleHelpers.CreateScope("GetResponseWithDistributedTracingHeaders"))
{
// Create separate request objects since .NET Core asserts only one response per request
HttpWebRequest request = (HttpWebRequest)System.Net.WebRequest.Create(GetUrlForTest("GetResponse", url));
if (tracingDisabled)
{
request.Headers.Add(TracingEnabled, "false");
}

// Test the behavior when distributed tracing headers are manually propagated
// to the outgoing HTTP span.
// They should be overridden by the tracer
request.Headers.Add(GenerateCurrentDistributedTracingHeaders());

request.GetResponse().Close();
Console.WriteLine("Received response for request.GetResponse()");
}

using (_sampleHelpers.CreateScope("GetResponseNotFound"))
{
// Create separate request objects since .NET Core asserts only one response per request
Expand Down Expand Up @@ -378,6 +397,24 @@ public static async Task SendWebRequestRequests(bool tracingDisabled, string url
}
}

using (_sampleHelpers.CreateScope("GetResponseAsyncWithDistributedTracingHeaders"))
{
// Create separate request objects since .NET Core asserts only one response per request
HttpWebRequest request = (HttpWebRequest)System.Net.WebRequest.Create(GetUrlForTest("GetResponseAsync", url));
if (tracingDisabled)
{
request.Headers.Add(TracingEnabled, "false");
}

// Test the behavior when distributed tracing headers are manually propagated
// to the outgoing HTTP span.
// They should be overridden by the tracer
request.Headers.Add(GenerateCurrentDistributedTracingHeaders());

(await request.GetResponseAsync()).Close();
Console.WriteLine("Received response for request.GetResponseAsync()");
}

using (_sampleHelpers.CreateScope("GetResponseAsync"))
{
// Create separate request objects since .NET Core asserts only one response per request
Expand Down Expand Up @@ -442,16 +479,40 @@ public static async Task SendWebRequestRequests(bool tracingDisabled, string url
GetRequestStream(tracingDisabled, url);
}

using (_sampleHelpers.CreateScope("GetRequestStreamWithDistributedTracingHeaders"))
{
// Test the behavior when distributed tracing headers are manually propagated
// to the outgoing HTTP span.
// They should be overridden by the tracer
GetRequestStream(tracingDisabled, url, GenerateCurrentDistributedTracingHeaders());
}

using (_sampleHelpers.CreateScope("BeginGetRequestStream"))
{
BeginGetRequestStream(tracingDisabled, url);
}

using (_sampleHelpers.CreateScope("BeginGetRequestStreamWithDistributedTracingHeaders"))
{
// Test the behavior when distributed tracing headers are manually propagated
// to the outgoing HTTP span.
// They should be overridden by the tracer
BeginGetRequestStream(tracingDisabled, url, GenerateCurrentDistributedTracingHeaders());
}

using (_sampleHelpers.CreateScope("BeginGetResponse"))
{
BeginGetResponse(tracingDisabled, "BeginGetResponseAsync", url);
}

using (_sampleHelpers.CreateScope("BeginGetResponseWithDistributedTracingHeaders"))
{
// Test the behavior when distributed tracing headers are manually propagated
// to the outgoing HTTP span.
// They should be overridden by the tracer
BeginGetResponse(tracingDisabled, "BeginGetResponseAsync", url, GenerateCurrentDistributedTracingHeaders());
}

using (_sampleHelpers.CreateScope("BeginGetResponseNotFound"))
{
BeginGetResponse(tracingDisabled, "BeginGetResponseNotFoundAsync", url);
Expand Down Expand Up @@ -484,7 +545,7 @@ await Task.Factory.FromAsync(

}

private static void BeginGetResponse(bool tracingDisabled, string testName, string url)
private static void BeginGetResponse(bool tracingDisabled, string testName, string url, NameValueCollection additionalHeaders = null)
{
// Create separate request objects since .NET Core asserts only one response per request
HttpWebRequest request = (HttpWebRequest)System.Net.WebRequest.Create(GetUrlForTest(testName, url));
Expand All @@ -497,6 +558,11 @@ private static void BeginGetResponse(bool tracingDisabled, string testName, stri
request.Headers.Add(TracingEnabled, "false");
}

if (additionalHeaders is not null)
{
request.Headers.Add(additionalHeaders);
}

var stream = request.GetRequestStream();
stream.Write(new byte[1], 0, 1);

Expand All @@ -522,7 +588,7 @@ private static void BeginGetResponse(bool tracingDisabled, string testName, stri
_allDone.WaitOne();
}

private static void BeginGetRequestStream(bool tracingDisabled, string url)
private static void BeginGetRequestStream(bool tracingDisabled, string url, NameValueCollection additionalHeaders = null)
{
// Create separate request objects since .NET Core asserts only one response per request
HttpWebRequest request = (HttpWebRequest)System.Net.WebRequest.Create(GetUrlForTest("BeginGetRequestStream", url));
Expand All @@ -535,6 +601,11 @@ private static void BeginGetRequestStream(bool tracingDisabled, string url)
request.Headers.Add(TracingEnabled, "false");
}

if (additionalHeaders is not null)
{
request.Headers.Add(additionalHeaders);
}

request.BeginGetRequestStream(
iar =>
{
Expand All @@ -553,7 +624,7 @@ private static void BeginGetRequestStream(bool tracingDisabled, string url)
_allDone.WaitOne();
}

private static void GetRequestStream(bool tracingDisabled, string url)
private static void GetRequestStream(bool tracingDisabled, string url, NameValueCollection additionalHeaders = null)
{
// Create separate request objects since .NET Core asserts only one response per request
HttpWebRequest request = (HttpWebRequest)System.Net.WebRequest.Create(GetUrlForTest("GetRequestStream", url));
Expand All @@ -566,6 +637,11 @@ private static void GetRequestStream(bool tracingDisabled, string url)
request.Headers.Add(TracingEnabled, "false");
}

if (additionalHeaders is not null)
{
request.Headers.Add(additionalHeaders);
}

var stream = request.GetRequestStream();
stream.Write(new byte[1], 0, 1);

Expand All @@ -577,5 +653,18 @@ private static string GetUrlForTest(string testName, string baseUrl)
{
return baseUrl + "?" + testName;
}

private static NameValueCollection GenerateCurrentDistributedTracingHeaders()
{
var current = Activity.Current;
var decimalTraceId = Convert.ToUInt64(current.TraceId.ToHexString().Substring(16, 16), 16);
var decimalSpanId = Convert.ToUInt64(current.SpanId.ToHexString(), 16);

return new NameValueCollection()
{
{ "traceparent", current.Id },
{ "tracestate", current.TraceStateString },
};
}
}
}

0 comments on commit 7124515

Please sign in to comment.