Skip to content

Commit

Permalink
get the refresh period from the token
Browse files Browse the repository at this point in the history
  • Loading branch information
rido-min committed Jun 28, 2024
1 parent e0411ee commit fd15f58
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
12 changes: 7 additions & 5 deletions mqttclients/dotnet/MQTTnet.Client.Extensions/WithJWT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@ namespace MQTTnet.Client.Extensions
{
public static partial class MqttNetExtensions
{
static Func<byte[]> _getTokenCallBack = null!;
static Func<(byte[], TimeSpan)> _getTokenCallBack = null!;
static Timer _refreshTimer = null!;


public static MqttClientOptionsBuilder WithJWT(this MqttClientOptionsBuilder builder, MqttConnectionSettings cs, Func<byte[]> getTokenCallBack, IMqttClient mqttClient, TimeSpan refreshPeriod)
public static MqttClientOptionsBuilder WithJWT(this MqttClientOptionsBuilder builder, MqttConnectionSettings cs, Func<(byte[], TimeSpan)> getTokenCallBack, IMqttClient mqttClient)
{
_getTokenCallBack = getTokenCallBack;

(byte[] token, TimeSpan ts) = getTokenCallBack();
Trace.TraceInformation($"Token expires in {ts.TotalSeconds} seconds");
builder
.WithConnectionSettings(cs)
.WithAuthentication("OAUTH2-JWT", getTokenCallBack());
.WithAuthentication("OAUTH2-JWT", token);

_refreshTimer = new Timer(RefreshToken, mqttClient, 0, Convert.ToInt32(refreshPeriod.TotalMilliseconds));
_refreshTimer = new Timer(RefreshToken, mqttClient, 0, (int)ts.TotalSeconds * 1000);
return builder;
}

Expand All @@ -30,7 +32,7 @@ static void RefreshToken(object? state)
await mqttClient.SendExtendedAuthenticationExchangeDataAsync(
new MqttExtendedAuthenticationExchangeData()
{
AuthenticationData = _getTokenCallBack(),
AuthenticationData = _getTokenCallBack().Item1,
ReasonCode = MQTTnet.Protocol.MqttAuthenticateReasonCode.ReAuthenticate
});
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

IMqttClient mqttClient = new MqttFactory().CreateMqttClient(MqttNetTraceLogger.CreateTraceLogger());
MqttClientConnectResult connAck = await mqttClient!.ConnectAsync(new MqttClientOptionsBuilder()
.WithJWT(cs, GetToken, mqttClient, TimeSpan.FromHours(1))
.WithJWT(cs, GetToken, mqttClient)
.Build());

Console.WriteLine($"Client Connected: {mqttClient.IsConnected} with CONNACK: {connAck.ResultCode} with auth method {mqttClient.Options.AuthenticationMethod}");
Expand All @@ -33,10 +33,10 @@
await Task.Delay(10000);
}

static byte[] GetToken()
static (byte[], TimeSpan) GetToken()
{
DefaultAzureCredential defaultCredential = new();
Console.WriteLine($"---- Get Token {DateTime.Now.ToString("o")} ----");
AccessToken jwt = defaultCredential.GetToken(new TokenRequestContext(new string[] { "https://eventgrid.azure.net/.default" }));
return Encoding.UTF8.GetBytes(jwt.Token);
return (Encoding.UTF8.GetBytes(jwt.Token), jwt.ExpiresOn.Subtract(DateTime.UtcNow));
}

0 comments on commit fd15f58

Please sign in to comment.