diff --git a/mqttclients/dotnet/MQTTnet.Client.Extensions/WithJWT.cs b/mqttclients/dotnet/MQTTnet.Client.Extensions/WithJWT.cs index 7165b76..47478f6 100644 --- a/mqttclients/dotnet/MQTTnet.Client.Extensions/WithJWT.cs +++ b/mqttclients/dotnet/MQTTnet.Client.Extensions/WithJWT.cs @@ -4,19 +4,21 @@ namespace MQTTnet.Client.Extensions { public static partial class MqttNetExtensions { - static Func _getTokenCallBack = null!; + static Func<(byte[], TimeSpan)> _getTokenCallBack = null!; static Timer _refreshTimer = null!; - public static MqttClientOptionsBuilder WithJWT(this MqttClientOptionsBuilder builder, MqttConnectionSettings cs, Func getTokenCallBack, IMqttClient mqttClient, TimeSpan refreshPeriod) + public static MqttClientOptionsBuilder WithJWT(this MqttClientOptionsBuilder builder, MqttConnectionSettings cs, Func<(byte[], TimeSpan)> getTokenCallBack, IMqttClient mqttClient) { _getTokenCallBack = getTokenCallBack; + (byte[] token, TimeSpan ts) = getTokenCallBack(); + Trace.TraceInformation($"Token expires in {ts.TotalSeconds} seconds"); builder .WithConnectionSettings(cs) - .WithAuthentication("OAUTH2-JWT", getTokenCallBack()); + .WithAuthentication("OAUTH2-JWT", token); - _refreshTimer = new Timer(RefreshToken, mqttClient, 0, Convert.ToInt32(refreshPeriod.TotalMilliseconds)); + _refreshTimer = new Timer(RefreshToken, mqttClient, 0, (int)ts.TotalSeconds * 1000); return builder; } @@ -30,7 +32,7 @@ static void RefreshToken(object? state) await mqttClient.SendExtendedAuthenticationExchangeDataAsync( new MqttExtendedAuthenticationExchangeData() { - AuthenticationData = _getTokenCallBack(), + AuthenticationData = _getTokenCallBack().Item1, ReasonCode = MQTTnet.Protocol.MqttAuthenticateReasonCode.ReAuthenticate }); }); diff --git a/scenarios/jwt_authentication/dotnet/jwt_authentication/Program.cs b/scenarios/jwt_authentication/dotnet/jwt_authentication/Program.cs index e36570e..110f02b 100644 --- a/scenarios/jwt_authentication/dotnet/jwt_authentication/Program.cs +++ b/scenarios/jwt_authentication/dotnet/jwt_authentication/Program.cs @@ -12,7 +12,7 @@ IMqttClient mqttClient = new MqttFactory().CreateMqttClient(MqttNetTraceLogger.CreateTraceLogger()); MqttClientConnectResult connAck = await mqttClient!.ConnectAsync(new MqttClientOptionsBuilder() - .WithJWT(cs, GetToken, mqttClient, TimeSpan.FromHours(1)) + .WithJWT(cs, GetToken, mqttClient) .Build()); Console.WriteLine($"Client Connected: {mqttClient.IsConnected} with CONNACK: {connAck.ResultCode} with auth method {mqttClient.Options.AuthenticationMethod}"); @@ -33,10 +33,10 @@ await Task.Delay(10000); } -static byte[] GetToken() +static (byte[], TimeSpan) GetToken() { DefaultAzureCredential defaultCredential = new(); Console.WriteLine($"---- Get Token {DateTime.Now.ToString("o")} ----"); AccessToken jwt = defaultCredential.GetToken(new TokenRequestContext(new string[] { "https://eventgrid.azure.net/.default" })); - return Encoding.UTF8.GetBytes(jwt.Token); + return (Encoding.UTF8.GetBytes(jwt.Token), jwt.ExpiresOn.Subtract(DateTime.UtcNow)); } \ No newline at end of file