Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the logging generator in LoggingChatClient / LoggingEmbeddingGenerator #5508

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@
using Microsoft.Extensions.Logging;
using Microsoft.Shared.Diagnostics;

#pragma warning disable EA0000 // Use source generated logging methods for improved performance
#pragma warning disable CA2254 // Template should be a static expression

namespace Microsoft.Extensions.AI;

/// <summary>A delegating chat client that logs chat operations to an <see cref="ILogger"/>.</summary>
public class LoggingChatClient : DelegatingChatClient
public partial class LoggingChatClient : DelegatingChatClient
{
/// <summary>An <see cref="ILogger"/> instance used for all logging.</summary>
private readonly ILogger _logger;
Expand Down Expand Up @@ -45,7 +42,18 @@ public JsonSerializerOptions JsonSerializerOptions
public override async Task<ChatCompletion> CompleteAsync(
IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
LogStart(chatMessages, options);
if (_logger.IsEnabled(LogLevel.Debug))
{
if (_logger.IsEnabled(LogLevel.Trace))
{
LogInvokedSensitive(nameof(CompleteAsync), AsJson(chatMessages), AsJson(options), AsJson(Metadata));
}
else
{
LogInvoked(nameof(CompleteAsync));
}
}

try
{
var completion = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false);
Expand All @@ -54,20 +62,24 @@ public override async Task<ChatCompletion> CompleteAsync(
{
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.Log(LogLevel.Trace, 0, (completion, _jsonSerializerOptions), null, static (state, _) =>
$"CompleteAsync completed: {JsonSerializer.Serialize(state.completion, state._jsonSerializerOptions.GetTypeInfo(typeof(ChatCompletion)))}");
LogCompletedSensitive(nameof(CompleteAsync), AsJson(completion));
}
else
{
_logger.LogDebug("CompleteAsync completed.");
LogCompleted(nameof(CompleteAsync));
}
}

return completion;
}
catch (Exception ex) when (ex is not OperationCanceledException)
catch (OperationCanceledException)
{
LogInvocationCanceled(nameof(CompleteAsync));
throw;
}
catch (Exception ex)
{
_logger.LogError(ex, "CompleteAsync failed.");
LogInvocationFailed(nameof(CompleteAsync), ex);
throw;
}
}
Expand All @@ -76,16 +88,31 @@ public override async Task<ChatCompletion> CompleteAsync(
public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
LogStart(chatMessages, options);
if (_logger.IsEnabled(LogLevel.Debug))
{
if (_logger.IsEnabled(LogLevel.Trace))
{
LogInvokedSensitive(nameof(CompleteStreamingAsync), AsJson(chatMessages), AsJson(options), AsJson(Metadata));
}
else
{
LogInvoked(nameof(CompleteStreamingAsync));
}
}

IAsyncEnumerator<StreamingChatCompletionUpdate> e;
try
{
e = base.CompleteStreamingAsync(chatMessages, options, cancellationToken).GetAsyncEnumerator(cancellationToken);
}
catch (Exception ex) when (ex is not OperationCanceledException)
catch (OperationCanceledException)
{
LogInvocationCanceled(nameof(CompleteStreamingAsync));
throw;
}
catch (Exception ex)
{
_logger.LogError(ex, "CompleteStreamingAsync failed.");
LogInvocationFailed(nameof(CompleteStreamingAsync), ex);
throw;
}

Expand All @@ -103,52 +130,63 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt

update = e.Current;
}
catch (Exception ex) when (ex is not OperationCanceledException)
catch (OperationCanceledException)
{
LogInvocationCanceled(nameof(CompleteStreamingAsync));
throw;
}
catch (Exception ex)
{
_logger.LogError(ex, "CompleteStreamingAsync failed.");
LogInvocationFailed(nameof(CompleteStreamingAsync), ex);
throw;
}

if (_logger.IsEnabled(LogLevel.Debug))
{
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.Log(LogLevel.Trace, 0, (update, _jsonSerializerOptions), null, static (state, _) =>
$"CompleteStreamingAsync received update: {JsonSerializer.Serialize(state.update, state._jsonSerializerOptions.GetTypeInfo(typeof(StreamingChatCompletionUpdate)))}");
LogStreamingUpdateSensitive(AsJson(update));
}
else
{
_logger.LogDebug("CompleteStreamingAsync received update.");
LogStreamingUpdate();
}
}

yield return update;
}

_logger.LogDebug("CompleteStreamingAsync completed.");
LogCompleted(nameof(CompleteStreamingAsync));
}
finally
{
await e.DisposeAsync().ConfigureAwait(false);
}
}

private void LogStart(IList<ChatMessage> chatMessages, ChatOptions? options, [CallerMemberName] string? methodName = null)
{
if (_logger.IsEnabled(LogLevel.Debug))
{
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.Log(LogLevel.Trace, 0, (methodName, chatMessages, options, this), null, static (state, _) =>
$"{state.methodName} invoked: " +
$"Messages: {JsonSerializer.Serialize(state.chatMessages, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(IList<ChatMessage>)))}. " +
$"Options: {JsonSerializer.Serialize(state.options, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(ChatOptions)))}. " +
$"Metadata: {JsonSerializer.Serialize(state.Item4.Metadata, state.Item4._jsonSerializerOptions.GetTypeInfo(typeof(ChatClientMetadata)))}.");
}
else
{
_logger.LogDebug($"{methodName} invoked.");
}
}
}
private string AsJson<T>(T value) => JsonSerializer.Serialize(value, _jsonSerializerOptions.GetTypeInfo(typeof(T)));

[LoggerMessage(LogLevel.Debug, "{MethodName} invoked.")]
private partial void LogInvoked(string methodName);

[LoggerMessage(LogLevel.Trace, "{MethodName} invoked: {ChatMessages}. Options: {ChatOptions}. Metadata: {ChatClientMetadata}.")]
private partial void LogInvokedSensitive(string methodName, string chatMessages, string chatOptions, string chatClientMetadata);
Comment on lines +172 to +173
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As it's already checked if the LogLevel is enabled, there's no need for the SGen to generate that check too. So set

Suggested change
[LoggerMessage(LogLevel.Trace, "{MethodName} invoked: {ChatMessages}. Options: {ChatOptions}. Metadata: {ChatClientMetadata}.")]
private partial void LogInvokedSensitive(string methodName, string chatMessages, string chatOptions, string chatClientMetadata);
[LoggerMessage(LogLevel.Trace, "{MethodName} invoked: {ChatMessages}. Options: {ChatOptions}. Metadata: {ChatClientMetadata}.", SkipEnabledCheck = true)]
private partial void LogInvokedSensitive(string methodName, string chatMessages, string chatOptions, string chatClientMetadata);

in order to avoid the redundant check be emitted.

Same on other places.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't do so because these are only Debug/Trace events, and in the case where such events are enabled, the extra interface call will be dwarfed by all the JSON serialization and other overheads. And while I could just enable it anyway, not all of the call sites are currently guarded, which makes it complicated for someone maintaining the code to know which call sites need guarding and which don't. Which then leads to consistently putting SkipEnabledCheck on all of the log methods and consistently guarding all call sites, but that in turn makes the code more expensive. Seemed best to just keep it simple and pay for the extra enabled check when there's already a ton of overhead / when debugging.


[LoggerMessage(LogLevel.Debug, "{MethodName} completed.")]
private partial void LogCompleted(string methodName);

[LoggerMessage(LogLevel.Trace, "{MethodName} completed: {ChatCompletion}.")]
private partial void LogCompletedSensitive(string methodName, string chatCompletion);

[LoggerMessage(LogLevel.Debug, "CompleteStreamingAsync received update.")]
private partial void LogStreamingUpdate();

[LoggerMessage(LogLevel.Trace, "CompleteStreamingAsync received update: {StreamingChatCompletionUpdate}")]
private partial void LogStreamingUpdateSensitive(string streamingChatCompletionUpdate);

[LoggerMessage(LogLevel.Debug, "{MethodName} canceled.")]
private partial void LogInvocationCanceled(string methodName);

[LoggerMessage(LogLevel.Error, "{MethodName} failed.")]
private partial void LogInvocationFailed(string methodName, Exception error);
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
using Microsoft.Extensions.Logging;
using Microsoft.Shared.Diagnostics;

#pragma warning disable EA0000 // Use source generated logging methods for improved performance

namespace Microsoft.Extensions.AI;

/// <summary>A delegating embedding generator that logs embedding generation operations to an <see cref="ILogger"/>.</summary>
/// <typeparam name="TInput">Specifies the type of the input passed to the generator.</typeparam>
/// <typeparam name="TEmbedding">Specifies the type of the embedding instance produced by the generator.</typeparam>
public class LoggingEmbeddingGenerator<TInput, TEmbedding> : DelegatingEmbeddingGenerator<TInput, TEmbedding>
public partial class LoggingEmbeddingGenerator<TInput, TEmbedding> : DelegatingEmbeddingGenerator<TInput, TEmbedding>
where TEmbedding : Embedding
{
/// <summary>An <see cref="ILogger"/> instance used for all logging.</summary>
Expand Down Expand Up @@ -50,33 +48,48 @@ public override async Task<GeneratedEmbeddings<TEmbedding>> GenerateAsync(IEnume
{
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.Log(LogLevel.Trace, 0, (values, options, this), null, static (state, _) =>
"GenerateAsync invoked: " +
$"Values: {JsonSerializer.Serialize(state.values, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(IEnumerable<TInput>)))}. " +
$"Options: {JsonSerializer.Serialize(state.options, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(EmbeddingGenerationOptions)))}. " +
$"Metadata: {JsonSerializer.Serialize(state.Item3.Metadata, state.Item3._jsonSerializerOptions.GetTypeInfo(typeof(EmbeddingGeneratorMetadata)))}.");
LogInvokedSensitive(AsJson(values), AsJson(options), AsJson(Metadata));
}
else
{
_logger.LogDebug("GenerateAsync invoked.");
LogInvoked();
}
}

try
{
var embeddings = await base.GenerateAsync(values, options, cancellationToken).ConfigureAwait(false);

if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("GenerateAsync generated {Count} embedding(s).", embeddings.Count);
}
LogCompleted(embeddings.Count);

return embeddings;
}
catch (Exception ex) when (ex is not OperationCanceledException)
catch (OperationCanceledException)
{
LogInvocationCanceled();
throw;
}
catch (Exception ex)
{
_logger.LogError(ex, "GenerateAsync failed.");
LogInvocationFailed(ex);
throw;
}
}

private string AsJson<T>(T value) => JsonSerializer.Serialize(value, _jsonSerializerOptions.GetTypeInfo(typeof(T)));

[LoggerMessage(LogLevel.Debug, "GenerateAsync invoked.")]
private partial void LogInvoked();

[LoggerMessage(LogLevel.Trace, "GenerateAsync invoked: {Values}. Options: {EmbeddingGenerationOptions}. Metadata: {EmbeddingGeneratorMetadata}.")]
private partial void LogInvokedSensitive(string values, string embeddingGenerationOptions, string embeddingGeneratorMetadata);

[LoggerMessage(LogLevel.Debug, "GenerateAsync generated {EmbeddingsCount} embedding(s).")]
private partial void LogCompleted(int embeddingsCount);

[LoggerMessage(LogLevel.Debug, "GenerateAsync canceled.")]
private partial void LogInvocationCanceled();

[LoggerMessage(LogLevel.Error, "GenerateAsync failed.")]
private partial void LogInvocationFailed(Exception error);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
<PropertyGroup>
<InjectSharedCollectionExtensions>true</InjectSharedCollectionExtensions>
<InjectSharedEmptyCollections>true</InjectSharedEmptyCollections>
<DisableMicrosoftExtensionsLoggingSourceGenerator>false</DisableMicrosoftExtensionsLoggingSourceGenerator>
</PropertyGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,8 @@ await chatClient.CompleteAsync(

Assert.Collection(logger.Entries,
entry => Assert.Contains("What is the current secret number?", entry.Message),
entry => Assert.Contains("\"name\":\"GetSecretNumber\"", entry.Message),
entry => Assert.Contains($"\"result\":{secretNumber}", entry.Message),
entry => Assert.Contains("\"name\": \"GetSecretNumber\"", entry.Message),
entry => Assert.Contains($"\"result\": {secretNumber}", entry.Message),
entry => Assert.Contains(secretNumber.ToString(), entry.Message));
}

Expand All @@ -528,8 +528,8 @@ public virtual async Task Logging_LogsFunctionCalls_Streaming()
}

Assert.Contains(logger.Entries, e => e.Message.Contains("What is the current secret number?"));
Assert.Contains(logger.Entries, e => e.Message.Contains("\"name\":\"GetSecretNumber\""));
Assert.Contains(logger.Entries, e => e.Message.Contains($"\"result\":{secretNumber}"));
Assert.Contains(logger.Entries, e => e.Message.Contains("\"name\": \"GetSecretNumber\""));
Assert.Contains(logger.Entries, e => e.Message.Contains($"\"result\": {secretNumber}"));
}

[ConditionalFact]
Expand Down
Loading