Skip to content

Commit

Permalink
Propagate async flow / exceptions via Task
Browse files Browse the repository at this point in the history
  • Loading branch information
ShortDevelopment committed Apr 3, 2024
1 parent 8ab3360 commit 9e70da9
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 98 deletions.
22 changes: 19 additions & 3 deletions lib/ShortDev.Microsoft.ConnectedDevices/CdpLog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,48 @@
using ShortDev.Microsoft.ConnectedDevices.Messages.Connection;
using ShortDev.Microsoft.ConnectedDevices.Messages.Control;
using ShortDev.Microsoft.ConnectedDevices.Transports;
using System;
using System.Collections.Generic;

namespace ShortDev.Microsoft.ConnectedDevices;

internal static partial class CdpLog
{
#region Advertising
[LoggerMessage(EventId = 101, Level = LogLevel.Information, Message = "Advertising started")]
public static partial void AdvertisingStarted(this ILogger logger);

[LoggerMessage(EventId = 102, Level = LogLevel.Information, Message = "Advertising stopped")]
public static partial void AdvertisingStopped(this ILogger logger);

[LoggerMessage(EventId = 106, Level = LogLevel.Information, Message = "Error while advertising")]
public static partial void AdvertisingError(this ILogger logger, Exception ex);
#endregion

#region Listening
[LoggerMessage(EventId = 103, Level = LogLevel.Information, Message = "Listening started")]
public static partial void ListeningStarted(this ILogger logger);

[LoggerMessage(EventId = 104, Level = LogLevel.Information, Message = "Listening stopped")]
public static partial void ListeningStopped(this ILogger logger);

[LoggerMessage(EventId = 107, Level = LogLevel.Information, Message = "Error while listening")]
public static partial void ListeningError(this ILogger logger, Exception ex);
#endregion

#region Discovery
[LoggerMessage(EventId = 108, Level = LogLevel.Information, Message = "Discovery started on {TransportTypes}")]
public static partial void DiscoveryStarted(this ILogger logger, IEnumerable<CdpTransportType> transportTypes);

[LoggerMessage(EventId = 109, Level = LogLevel.Information, Message = "Discovery stopped")]
public static partial void DiscoveryStopped(this ILogger logger);

[LoggerMessage(EventId = 110, Level = LogLevel.Information, Message = "Error during discovery")]
public static partial void DiscoveryError(this ILogger logger, Exception ex);
#endregion

[LoggerMessage(EventId = 105, Level = LogLevel.Information, Message = "New socket from endpoint {Endpoint}")]
public static partial void NewSocket(this ILogger logger, EndpointInfo endpoint);



[LoggerMessage(EventId = 201, Level = LogLevel.Error, Message = "Exception in session {SessionId:X}")]
public static partial void ExceptionInSession(this ILogger logger, Exception ex, ulong sessionId);

Expand Down
121 changes: 72 additions & 49 deletions lib/ShortDev.Microsoft.ConnectedDevices/ConnectedDevicesPlatform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@ public sealed class ConnectedDevicesPlatform(LocalDeviceInfo deviceInfo, ILogger
readonly ILogger<ConnectedDevicesPlatform> _logger = loggerFactory.CreateLogger<ConnectedDevicesPlatform>();

#region Transport
readonly ConcurrentDictionary<Type, ICdpTransport> _transports = new();
readonly ConcurrentDictionary<Type, ICdpTransport> _transportMap = new();
public void AddTransport<T>(T transport) where T : ICdpTransport
{
_transports.AddOrUpdate(typeof(T), transport, (_, old) =>
_transportMap.AddOrUpdate(typeof(T), transport, (_, old) =>
{
old.Dispose();
return transport;
});
}

public T? TryGetTransport<T>() where T : ICdpTransport
=> (T?)_transports.GetValueOrDefault(typeof(T));
=> (T?)_transportMap.GetValueOrDefault(typeof(T));

public ICdpTransport? TryGetTransport(CdpTransportType transportType)
=> _transports.Values.SingleOrDefault(transport => transport.TransportType == transportType);
=> _transportMap.Values.SingleOrDefault(transport => transport.TransportType == transportType);
#endregion

#region Host
Expand All @@ -43,18 +43,21 @@ public async void Advertise(CancellationToken cancellationToken)
using var isAdvertising = IsAdvertising.Lock();

_logger.AdvertisingStarted();

foreach (var (_, transport) in _transports)
try
{
if (transport is not ICdpDiscoverableTransport discoverableTransport)
continue;

discoverableTransport.Advertise(DeviceInfo, cancellationToken);
await Task.WhenAll(_transportMap.Values
.OfType<ICdpDiscoverableTransport>()
.Select(x => x.Advertise(DeviceInfo, cancellationToken))
).ConfigureAwait(false);
}
catch (Exception ex)
{
_logger.AdvertisingError(ex);
}
finally
{
_logger.AdvertisingStopped();
}

await cancellationToken.AwaitCancellation();

_logger.AdvertisingStopped();
}
#endregion

Expand All @@ -65,21 +68,31 @@ public async void Listen(CancellationToken cancellationToken)
using var isListening = IsListening.Lock();

_logger.ListeningStarted();

foreach (var (_, transport) in _transports)
try
{
transport.Listen(cancellationToken);
transport.DeviceConnected += OnDeviceConnected;
await Task.WhenAll(_transportMap.Values
.Select(async transport =>
{
transport.DeviceConnected += OnDeviceConnected;
try
{
await transport.Listen(cancellationToken).ConfigureAwait(false);
}
finally
{
transport.DeviceConnected -= OnDeviceConnected;
}
})
).ConfigureAwait(false);
}

await cancellationToken.AwaitCancellation();

foreach (var (_, transport) in _transports)
catch (Exception ex)
{
transport.DeviceConnected -= OnDeviceConnected;
_logger.ListeningError(ex);
}
finally
{
_logger.ListeningStopped();
}

_logger.ListeningStopped();
}

private void OnDeviceConnected(ICdpTransport sender, CdpSocket socket)
Expand All @@ -98,30 +111,39 @@ public async void Discover(CancellationToken cancellationToken)
{
using var isDiscovering = IsDiscovering.Lock();

foreach (var (_, transport) in _transports)
_logger.DiscoveryStarted(_transportMap.Values.Select(x => x.TransportType));
try
{
if (transport is not ICdpDiscoverableTransport discoverableTransport)
continue;

discoverableTransport.Discover(cancellationToken);
discoverableTransport.DeviceDiscovered += DeviceDiscovered;
await Task.WhenAll(_transportMap.Values
.OfType<ICdpDiscoverableTransport>()
.Select(async transport =>
{
transport.DeviceDiscovered += DeviceDiscovered;
try
{
await transport.Discover(cancellationToken).ConfigureAwait(false);
}
finally
{
transport.DeviceDiscovered -= DeviceDiscovered;
}
})
).ConfigureAwait(false);
}

await cancellationToken.AwaitCancellation();

foreach (var (_, transport) in _transports)
catch (Exception ex)
{
if (transport is not ICdpDiscoverableTransport discoverableTransport)
continue;

discoverableTransport.DeviceDiscovered -= DeviceDiscovered;
_logger.DiscoveryError(ex);
}
finally
{
_logger.DiscoveryStopped();
}
}

public async Task<CdpSession> ConnectAsync(EndpointInfo endpoint)
public async Task<CdpSession> ConnectAsync([NotNull] EndpointInfo endpoint)
{
var socket = await CreateSocketAsync(endpoint);
return await CdpSession.ConnectClientAsync(this, socket);
var socket = await CreateSocketAsync(endpoint).ConfigureAwait(false);
return await CdpSession.ConnectClientAsync(this, socket).ConfigureAwait(false);
}

internal async Task<CdpSocket> CreateSocketAsync(EndpointInfo endpoint)
Expand All @@ -130,7 +152,7 @@ internal async Task<CdpSocket> CreateSocketAsync(EndpointInfo endpoint)
return knownSocket;

var transport = TryGetTransport(endpoint.TransportType) ?? throw new InvalidOperationException($"No single transport found for type {endpoint.TransportType}");
var socket = await transport.ConnectAsync(endpoint);
var socket = await transport.ConnectAsync(endpoint).ConfigureAwait(false);
ReceiveLoop(socket);
return socket;
}
Expand All @@ -144,7 +166,7 @@ internal async Task<CdpSocket> CreateSocketAsync(EndpointInfo endpoint)
if (transport == null)
return null;

var socket = await transport.TryConnectAsync(endpoint, connectTimeout);
var socket = await transport.TryConnectAsync(endpoint, connectTimeout).ConfigureAwait(false);
if (socket == null)
return null;

Expand Down Expand Up @@ -239,7 +261,7 @@ bool TryGetKnownSocket(EndpointInfo endpoint, [MaybeNullWhen(false)] out CdpSock
public CdpDeviceInfo GetCdpDeviceInfo()
{
List<EndpointInfo> endpoints = [];
foreach (var (_, transport) in _transports)
foreach (var (_, transport) in _transportMap)
{
try
{
Expand All @@ -257,17 +279,18 @@ public ILogger<T> CreateLogger<T>()
public void Dispose()
{
Extensions.DisposeAll(
_transports.Select(x => x.Value),
_transportMap.Select(x => x.Value),
_knownSockets.Select(x => x.Value)
);

_transports.Clear();
_transportMap.Clear();
_knownSockets.Clear();
}

public static X509Certificate2 CreateDeviceCertificate(CdpEncryptionParams encryptionParams)
public static X509Certificate2 CreateDeviceCertificate([NotNull] CdpEncryptionParams encryptionParams)
{
CertificateRequest certRequest = new("CN=Ms-Cdp", ECDsa.Create(encryptionParams.Curve), HashAlgorithmName.SHA256);
using var key = ECDsa.Create(encryptionParams.Curve);
CertificateRequest certRequest = new("CN=Ms-Cdp", key, HashAlgorithmName.SHA256);
return certRequest.CreateSelfSigned(DateTimeOffset.Now, DateTimeOffset.Now.AddYears(5));
}

Expand Down
16 changes: 5 additions & 11 deletions lib/ShortDev.Microsoft.ConnectedDevices/Extensions.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Buffers;
using System.Diagnostics.CodeAnalysis;
using System.Net.NetworkInformation;
using System.Security.Cryptography;
using System.Threading;
using System.Threading.Tasks;

namespace ShortDev.Microsoft.ConnectedDevices;

public static class Extensions
{
public static uint HighValue(this ulong value)
Expand All @@ -26,8 +20,8 @@ public static Task AwaitCancellation(this CancellationToken @this)

public static async Task<T?> WithTimeout<T>(this Task<T> task, TimeSpan timeout)
{
if (await Task.WhenAny(task, Task.Delay(timeout)) == task)
return task.Result;
if (await Task.WhenAny(task, Task.Delay(timeout)).ConfigureAwait(false) == task)
return task!.Result;
return default;
}

Expand All @@ -37,7 +31,7 @@ public static string ToStringFormatted(this PhysicalAddress @this)
public static void DisposeAll(params IEnumerable<IDisposable>[] disposables)
=> disposables.SelectMany(x => x).DisposeAll();

public static void DisposeAll<T>(this IEnumerable<T> disposables) where T : IDisposable
public static void DisposeAll<T>([NotNull] this IEnumerable<T> disposables) where T : IDisposable
{
List<Exception> exceptions = [];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ public sealed class BluetoothTransport(IBluetoothHandler handler) : ICdpTranspor
public IBluetoothHandler Handler { get; } = handler;

public event DeviceConnectedEventHandler? DeviceConnected;
public void Listen(CancellationToken cancellationToken)
public async Task Listen(CancellationToken cancellationToken)
{
_ = Handler.ListenRfcommAsync(
await Handler.ListenRfcommAsync(
new RfcommOptions()
{
ServiceId = Constants.RfcommServiceId,
Expand All @@ -27,9 +27,9 @@ public async Task<CdpSocket> ConnectAsync(EndpointInfo endpoint)
SocketConnected = (socket) => DeviceConnected?.Invoke(this, socket)
});

public void Advertise(LocalDeviceInfo deviceInfo, CancellationToken cancellationToken)
public async Task Advertise(LocalDeviceInfo deviceInfo, CancellationToken cancellationToken)
{
_ = Handler.AdvertiseBLeBeaconAsync(
await Handler.AdvertiseBLeBeaconAsync(
new AdvertiseOptions()
{
ManufacturerId = Constants.BLeBeaconManufacturerId,
Expand All @@ -40,9 +40,9 @@ public void Advertise(LocalDeviceInfo deviceInfo, CancellationToken cancellation
}

public event DeviceDiscoveredEventHandler? DeviceDiscovered;
public void Discover(CancellationToken cancellationToken)
public async Task Discover(CancellationToken cancellationToken)
{
_ = Handler.ScanBLeAsync(new()
await Handler.ScanBLeAsync(new()
{
OnDeviceDiscovered = (advertisement, rssi) =>
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
using System.Threading;

namespace ShortDev.Microsoft.ConnectedDevices.Transports;

namespace ShortDev.Microsoft.ConnectedDevices.Transports;
public interface ICdpDiscoverableTransport : ICdpTransport
{
void Advertise(LocalDeviceInfo deviceInfo, CancellationToken cancellationToken);
Task Advertise(LocalDeviceInfo deviceInfo, CancellationToken cancellationToken);
event DeviceDiscoveredEventHandler? DeviceDiscovered;
void Discover(CancellationToken cancellationToken);
Task Discover(CancellationToken cancellationToken);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ public interface ICdpTransport : IDisposable
{
try
{
return await ConnectAsync(endpoint).WithTimeout(timeout);
return await ConnectAsync(endpoint).WithTimeout(timeout).ConfigureAwait(false);
}
catch { }
return null;
}

public event DeviceConnectedEventHandler? DeviceConnected;
void Listen(CancellationToken cancellationToken);
Task Listen(CancellationToken cancellationToken);
}
Loading

0 comments on commit 9e70da9

Please sign in to comment.