Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with potential deadlock when reading a managed stream that we wrap from native code #1126

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/Tests/UnitTest/TestComponentCSharp_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
using Windows.Security.Cryptography.Core;
using System.Reflection;
using Windows.Devices.Enumeration.Pnp;

using System.Text;

#if NET
using WeakRefNS = System;
#else
Expand Down Expand Up @@ -389,7 +390,21 @@ public void TestEmptyBufferCopyTo()
byte[] array = { };
buffer.CopyTo(array);
Assert.True(array.Length == 0);
}

#if NET
[Fact]
public async void TestRandomStreamWithContext()
{
string str = "UnitTest.TestCSharp+CustomDictionary, UnitTest, Version=1.0.0.0, Culture=neutral, PublicKeyToken=null";
var stream = new MemoryStream(Encoding.UTF8.GetBytes(str)).AsRandomAccessStream();
var buffer = new WindowsRuntimeBuffer(150);
var result = await stream.ReadAsync(buffer, 150, InputStreamOptions.None);
Assert.NotNull(result);
Assert.True(buffer.TryGetUnderlyingData(out byte[] data, out int _));
Assert.Equal(str, Encoding.UTF8.GetString(data).Trim('\0'));
}
#endif

[Fact]
public void TestTypePropertyWithSystemType()
Expand Down
27 changes: 27 additions & 0 deletions src/cswinrt/strings/additions/Windows.Foundation/AsyncInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ public static IAsyncActionWithProgress<TProgress> Run<TProgress>(Func<Cancellati
throw new ArgumentNullException(nameof(taskProvider));

return new TaskToAsyncActionWithProgressAdapter<TProgress>(taskProvider);
}

internal static IAsyncActionWithProgress<TProgress> RunWithoutCapturedContext<TProgress>(Func<CancellationToken, IProgress<TProgress>, Task> taskProvider)
{
if (taskProvider == null)
throw new ArgumentNullException(nameof(taskProvider));

return new TaskToAsyncActionWithProgressAdapter<TProgress>(taskProvider, false);
}


Expand All @@ -90,6 +98,16 @@ public static IAsyncOperation<TResult> Run<TResult>(Func<CancellationToken, Task
return new TaskToAsyncOperationAdapter<TResult>(taskProvider);
}

internal static IAsyncOperation<TResult> RunWithoutCapturedContext<TResult>(Func<CancellationToken, Task<TResult>> taskProvider)
{
// This is only internal to reduce the number of public overloads.
// Code execution flows through this method when the method above is called. We can always make this public.

if (taskProvider == null)
throw new ArgumentNullException(nameof(taskProvider));

return new TaskToAsyncOperationAdapter<TResult>(taskProvider, false);
}

/// <summary>
/// Creates and starts an <see cref="IAsyncOperationWithProgress{TResult, TProgress}"/> instance
Expand All @@ -114,6 +132,15 @@ public static IAsyncOperationWithProgress<TResult, TProgress> Run<TResult, TProg
throw new ArgumentNullException(nameof(taskProvider));

return new TaskToAsyncOperationWithProgressAdapter<TResult, TProgress>(taskProvider);
}

internal static IAsyncOperationWithProgress<TResult, TProgress> RunWithoutCapturedContext<TResult, TProgress>(
Func<CancellationToken, IProgress<TProgress>, Task<TResult>> taskProvider)
{
if (taskProvider == null)
throw new ArgumentNullException(nameof(taskProvider));

return new TaskToAsyncOperationWithProgressAdapter<TResult, TProgress>(taskProvider, false);
}

#endregion Factory methods for creating "normal" IAsyncInfo instances backed by a Task created by a pastProvider delegate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ internal sealed class TaskToAsyncActionAdapter
: TaskToAsyncInfoAdapter<AsyncActionCompletedHandler, VoidReferenceTypeParameter, VoidValueTypeParameter, VoidValueTypeParameter>,
IAsyncAction
{
internal TaskToAsyncActionAdapter(Delegate taskGenerator)
internal TaskToAsyncActionAdapter(Delegate taskGenerator, bool executeHandlersOnCapturedContext = true)

: base(taskGenerator)
: base(taskGenerator, executeHandlersOnCapturedContext)
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ internal sealed class TaskToAsyncActionWithProgressAdapter<TProgress>
TProgress>,
IAsyncActionWithProgress<TProgress>
{
internal TaskToAsyncActionWithProgressAdapter(Delegate taskGenerator)
internal TaskToAsyncActionWithProgressAdapter(Delegate taskGenerator, bool executeHandlersOnCapturedContext = true)

: base(taskGenerator)
: base(taskGenerator, executeHandlersOnCapturedContext)
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ private static InvalidOperationException CreateCannotGetResultsFromIncompleteOpe
/// <summary>Creates an IAsyncInfo from the specified delegate. The delegate will be called to construct a task that will
/// represent the future encapsulated by this IAsyncInfo.</summary>
/// <param name="taskProvider">The task generator to use for creating the task.</param>
internal TaskToAsyncInfoAdapter(Delegate taskProvider)
internal TaskToAsyncInfoAdapter(Delegate taskProvider, bool executeHandlersOnCapturedContext = true)
{
Debug.Assert(taskProvider != null);
Debug.Assert((null != (taskProvider as Func<Task>))
Expand All @@ -139,7 +139,10 @@ internal TaskToAsyncInfoAdapter(Delegate taskProvider)

// The IAsyncInfo is reasonably expected to be created/started by the same code that wires up the Completed and Progress handlers.
// Record the current SynchronizationContext so that we can invoke completion and progress callbacks in it later.
_startingContext = GetStartingContext();
if (executeHandlersOnCapturedContext)
{
_startingContext = GetStartingContext();
}

// Construct task from the specified provider:
Task task = InvokeTaskProvider(taskProvider);
Expand Down Expand Up @@ -171,7 +174,9 @@ internal TaskToAsyncInfoAdapter(Delegate taskProvider)
/// <param name="underlyingProgressDispatcher">A progress listener/pugblisher that receives progress notifications
/// form <code>underlyingTask</code>.</param>
internal TaskToAsyncInfoAdapter(Task underlyingTask,
CancellationTokenSource underlyingCancelTokenSource, Progress<TProgressInfo> underlyingProgressDispatcher)
CancellationTokenSource underlyingCancelTokenSource,
Progress<TProgressInfo> underlyingProgressDispatcher,
bool executeHandlersOnCapturedContext = true)
{
if (underlyingTask == null)
throw new ArgumentNullException(nameof(underlyingTask));
Expand All @@ -182,7 +187,10 @@ internal TaskToAsyncInfoAdapter(Task underlyingTask,

// The IAsyncInfo is reasonably expected to be created/started by the same code that wires up the Completed and Progress handlers.
// Record the current SynchronizationContext so that we can invoke completion and progress callbacks in it later.
_startingContext = GetStartingContext();
if (executeHandlersOnCapturedContext)
{
_startingContext = GetStartingContext();
}

// We do not need to invoke any delegates to get the task, it is provided for us:
_dataContainer = underlyingTask;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ internal sealed class TaskToAsyncOperationAdapter<TResult>
: TaskToAsyncInfoAdapter<AsyncOperationCompletedHandler<TResult>, VoidReferenceTypeParameter, TResult, VoidValueTypeParameter>,
IAsyncOperation<TResult>
{
internal TaskToAsyncOperationAdapter(Delegate taskGenerator)
internal TaskToAsyncOperationAdapter(Delegate taskGenerator, bool executeHandlersOnCapturedContext = true)

: base(taskGenerator)
: base(taskGenerator, executeHandlersOnCapturedContext)
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ internal sealed class TaskToAsyncOperationWithProgressAdapter<TResult, TProgress
TProgress>,
IAsyncOperationWithProgress<TResult, TProgress>
{
internal TaskToAsyncOperationWithProgressAdapter(Delegate taskGenerator)
internal TaskToAsyncOperationWithProgressAdapter(Delegate taskGenerator, bool executeHandlersOnCapturedContext = true)

: base(taskGenerator)
: base(taskGenerator, executeHandlersOnCapturedContext)
{
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,12 @@ internal static IAsyncOperationWithProgress<IBuffer, uint> ReadAsync_AbstractStr

// If we got here, then no error was detected. Return the results buffer:
return dataBuffer;
}; // readOperation

return AsyncInfo.Run<IBuffer, uint>(readOperation);
}; // readOperation

// Construct and run the async operation. Any registered handlers for completion / progress
// on this async operation don't run on the captured context if any to avoid deadlock issues when
// CreateStreamOverRandomAccessStream is used to create a wrapper around this stream on a STA thread.
return AsyncInfo.RunWithoutCapturedContext<IBuffer, uint>(readOperation);
} // ReadAsync_AbstractStream

#endregion ReadAsync implementations
Expand Down Expand Up @@ -216,8 +219,10 @@ internal static IAsyncOperationWithProgress<uint, uint> WriteAsync_AbstractStrea
};
} // if-else

// Construct and run the async operation:
return AsyncInfo.Run<uint, uint>(writeOperation);
// Construct and run the async operation. Any registered handlers for completion / progress
// on this async operation don't run on the captured context if any to avoid deadlock issues when
// CreateStreamOverRandomAccessStream is used to create a wrapper around this stream on a STA thread.
return AsyncInfo.RunWithoutCapturedContext<uint, uint>(writeOperation);
} // WriteAsync_AbstractStream

#endregion WriteAsync implementations
Expand All @@ -239,8 +244,10 @@ internal static IAsyncOperation<bool> FlushAsync_AbstractStream(Stream stream)
return true;
};

// Construct and run the async operation:
return AsyncInfo.Run<bool>(flushOperation);
// Construct and run the async operation. Any registered handlers for completion / progress
// on this async operation don't run on the captured context if any to avoid deadlock issues when
// CreateStreamOverRandomAccessStream is used to create a wrapper around this stream on a STA thread.
return AsyncInfo.RunWithoutCapturedContext<bool>(flushOperation);
}
#endregion FlushAsync implementations

Expand Down