Skip to content

Commit

Permalink
Add support for subscribing to callbacks that return Task
Browse files Browse the repository at this point in the history
  • Loading branch information
xPaw committed Jan 4, 2025
1 parent fcd679e commit 6cc713b
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 8 deletions.
55 changes: 53 additions & 2 deletions SteamKit2/SteamKit2/Steam/SteamClient/CallbackMgr/CallbackMgr.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void RunWaitCallbacks()
public async Task RunWaitCallbackAsync( CancellationToken cancellationToken = default )
{
var call = await client.WaitForCallbackAsync( cancellationToken );
Handle( call );
await HandleAsync( call );
}

/// <summary>
Expand Down Expand Up @@ -139,6 +139,36 @@ public IDisposable Subscribe<TCallback>( Action<TCallback> callbackFunc )
return Subscribe( JobID.Invalid, callbackFunc );
}

/// <summary>
/// Registers the provided <see cref="Func{T, Task}"/> to receive callbacks of type <typeparamref name="TCallback" />.
/// </summary>
/// <param name="jobID">The <see cref="JobID"/> of the callbacks that should be subscribed to.
/// If this is <see cref="JobID.Invalid"/>, all callbacks of type <typeparamref name="TCallback" /> will be received.</param>
/// <param name="callbackFunc">The function to invoke with the callback.</param>
/// <typeparam name="TCallback">The type of callback to subscribe to.</typeparam>
/// <remarks>When subscribing to asynchronous methods, <see cref="RunWaitCallbackAsync"/> should be used for awaiting callbacks.</remarks>
/// <returns>An <see cref="IDisposable"/>. Disposing of the return value will unsubscribe the <paramref name="callbackFunc"/>.</returns>
public IDisposable Subscribe<TCallback>( JobID jobID, Func<TCallback, Task> callbackFunc ) where TCallback : CallbackMsg
{
ArgumentNullException.ThrowIfNull( jobID );
ArgumentNullException.ThrowIfNull( callbackFunc );

var callback = new Internal.AsyncCallback<TCallback>( callbackFunc, this, jobID );
return callback;
}

/// <summary>
/// Registers the provided <see cref="Func{T, Task}"/> to receive callbacks of type <typeparam name="TCallback" />.
/// </summary>
/// <param name="callbackFunc">The function to invoke with the callback.</param>
/// <remarks>When subscribing to asynchronous methods, <see cref="RunWaitCallbackAsync"/> should be used for awaiting callbacks.</remarks>
/// <returns>An <see cref="IDisposable"/>. Disposing of the return value will unsubscribe the <paramref name="callbackFunc"/>.</returns>
public IDisposable Subscribe<TCallback>( Func<TCallback, Task> callbackFunc )
where TCallback : CallbackMsg
{
return Subscribe( JobID.Invalid, callbackFunc );
}

/// <summary>
/// Registers the provided <see cref="Action{T}"/> to receive callbacks for notifications from the service of type <typeparam name="TService" />
/// with the notification message of type <typeparam name="TNotification"></typeparam>.
Expand Down Expand Up @@ -191,7 +221,28 @@ void Handle( CallbackMsg call )
{
if ( callback.CallbackType.IsAssignableFrom( type ) )
{
callback.Run( call );
var task = callback.Run( call );
task?.Wait();
}
}
}

async Task HandleAsync( CallbackMsg call )
{
var callbacks = registeredCallbacks;
var type = call.GetType();

// find handlers interested in this callback
foreach ( var callback in callbacks )
{
if ( callback.CallbackType.IsAssignableFrom( type ) )
{
var task = callback.Run( call );

if ( task != null )
{
await task;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


using System;
using System.Threading.Tasks;

namespace SteamKit2.Internal
{
Expand All @@ -16,22 +17,25 @@ namespace SteamKit2.Internal
abstract class CallbackBase
{
internal abstract Type CallbackType { get; }
internal abstract void Run( CallbackMsg callback );
internal abstract Task? Run( CallbackMsg callback );
internal bool IsAsync { get; init; }
}

sealed class Callback<TCall> : CallbackBase, IDisposable
where TCall : CallbackMsg
{
CallbackManager? mgr;

public JobID JobID { get; set; }
public JobID JobID { get; }

public Action<TCall> OnRun { get; set; }
public Action<TCall> OnRun { get; }

internal override Type CallbackType => typeof( TCall );

public Callback( Action<TCall> func, CallbackManager mgr, JobID jobID )
{
ArgumentNullException.ThrowIfNull( func );

this.JobID = jobID;
this.OnRun = func;
this.mgr = mgr;
Expand All @@ -52,13 +56,61 @@ public void Dispose()
System.GC.SuppressFinalize( this );
}

internal override void Run( CallbackMsg callback )
internal override Task? Run( CallbackMsg callback )
{
var cb = callback as TCall;
if ( cb != null && ( cb.JobID == JobID || JobID == JobID.Invalid ) && OnRun != null )
if ( cb != null && ( cb.JobID == JobID || JobID == JobID.Invalid ) )
{
OnRun( cb );
}
return null;
}
}

sealed class AsyncCallback<TCall> : CallbackBase, IDisposable
where TCall : CallbackMsg
{
CallbackManager? mgr;

public JobID JobID { get; }

public Func<TCall, Task> OnRun { get; }

internal override Type CallbackType => typeof( TCall );

public AsyncCallback( Func<TCall, Task> func, CallbackManager mgr, JobID jobID )
{
ArgumentNullException.ThrowIfNull( func );

this.IsAsync = true;
this.JobID = jobID;
this.OnRun = func;
this.mgr = mgr;

mgr.Register( this );
}

~AsyncCallback()
{
Dispose();
}

public void Dispose()
{
mgr?.Unregister( this );
mgr = null;

System.GC.SuppressFinalize( this );
}

internal override Task Run( CallbackMsg callback )
{
var cb = callback as TCall;
if ( cb != null && ( cb.JobID == JobID || JobID == JobID.Invalid ) )
{
return OnRun( cb );
}
return Task.CompletedTask;
}
}
}
34 changes: 33 additions & 1 deletion SteamKit2/Tests/CallbackManagerFacts.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using SteamKit2;
using Xunit;
Expand Down Expand Up @@ -257,7 +258,7 @@ void unsubscribe( CallbackForTest cb )
}

[Fact]
public void CorrectlysubscribesFromInsideOfCallback()
public void CorrectlySubscribesFromInsideOfCallback()
{
static void nothing( CallbackForTest cb )
{
Expand All @@ -275,6 +276,37 @@ void subscribe( CallbackForTest cb )
PostAndRunCallback( new CallbackForTest { UniqueID = Guid.NewGuid() } );
}

[Fact]
public async Task CorrectlyAwaitsForAsyncCallbacks()
{
var callback = new CallbackForTest { UniqueID = Guid.NewGuid() };

var numCallbacksRun = 0;
async Task action( CallbackForTest cb )
{
await Task.Delay( 100, TestContext.Current.CancellationToken );
Assert.Equal( callback.UniqueID, cb.UniqueID );
numCallbacksRun++;
}

using ( mgr.Subscribe<CallbackForTest>( action ) )
{
for ( var i = 0; i < 10; i++ )
{
client.PostCallback( callback );
}

for ( var i = 1; i <= 10; i++ )
{
await mgr.RunWaitCallbackAsync( TestContext.Current.CancellationToken );
Assert.Equal( i, numCallbacksRun );
}

mgr.RunWaitAllCallbacks( TimeSpan.Zero );
Assert.Equal( 10, numCallbacksRun );
}
}

void PostAndRunCallback(CallbackMsg callback)
{
client.PostCallback(callback);
Expand Down

0 comments on commit 6cc713b

Please sign in to comment.