Skip to content

Commit

Permalink
Fix parallel_tool_cals parameter not being sent
Browse files Browse the repository at this point in the history
  • Loading branch information
selfdocumentingcode committed Sep 6, 2024
1 parent d751426 commit 1ecdc2e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 51 deletions.
4 changes: 2 additions & 2 deletions Kattbot.Common/Models/KattGpt/ChatCompletionCreateRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public record ChatCompletionCreateRequest
/// https://platform.openai.com/docs/api-reference/chat/create#chat-create-parallel_tool_calls
/// </summary>
[JsonPropertyName("parallel_tool_calls")]
public bool ParallelToolCalls { get; set; }
public bool? ParallelToolCalls { get; set; }

/// <summary>
/// Gets or sets what sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more
Expand Down Expand Up @@ -135,4 +135,4 @@ public record ChatCompletionCreateRequest
/// </summary>
[JsonPropertyName("user")]
public string? User { get; set; }
}
}

Check warning on line 138 in Kattbot.Common/Models/KattGpt/ChatCompletionCreateRequest.cs

View workflow job for this annotation

GitHub Actions / Build

99 changes: 50 additions & 49 deletions Kattbot/NotificationHandlers/KattGptMessageHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,35 +65,35 @@ public KattGptMessageHandler(

public async Task Handle(MessageCreatedNotification notification, CancellationToken cancellationToken)
{
MessageCreatedEventArgs args = notification.EventArgs;
DiscordMessage message = args.Message;
DiscordUser author = args.Author;
DiscordChannel channel = args.Message.Channel ?? throw new Exception("Channel is null.");
MessageCreatedEventArgs? args = notification.EventArgs;
DiscordMessage? message = args.Message;
DiscordUser? author = args.Author;
DiscordChannel? channel = args.Message.Channel ?? throw new Exception("Channel is null.");

if (!ShouldHandleMessage(message)) return;

try
{
List<ChatCompletionMessage> systemPromptsMessages = _kattGptService.BuildSystemPromptsMessages(channel);
ChatCompletionFunction chatCompletionFunction = DalleToolBuilder.BuildDalleImageToolDefinition().Function;
List<ChatCompletionMessage>? systemPromptsMessages = _kattGptService.BuildSystemPromptsMessages(channel);
ChatCompletionFunction? chatCompletionFunction = DalleToolBuilder.BuildDalleImageToolDefinition().Function;
List<ChatCompletionMessage> newContextMessages = [];

KattGptChannelContext channelContext = GetOrCreateCachedContext(
KattGptChannelContext? channelContext = GetOrCreateCachedContext(
channel,
systemPromptsMessages,
chatCompletionFunction);

bool shouldReplyToMessage = ShouldReplyToMessage(message);

string recipientMarker = shouldReplyToMessage
string? recipientMarker = shouldReplyToMessage
? RecipientMarkerToYou
: RecipientMarkerToOthers;

// Add new message from notification
string newMessageUser = author.GetDisplayName();
string newMessageContent = message.SubstituteMentions();
string? newMessageUser = author.GetDisplayName();
string? newMessageContent = message.SubstituteMentions();

ChatCompletionMessage newUserMessage =
ChatCompletionMessage? newUserMessage =
ChatCompletionMessage.AsUser($"{newMessageUser}{recipientMarker}: {newMessageContent}");

newContextMessages.Add(newUserMessage);
Expand All @@ -106,22 +106,22 @@ public async Task Handle(MessageCreatedNotification notification, CancellationTo

await channel.TriggerTypingAsync();

ChatCompletionCreateRequest request = BuildRequest(
ChatCompletionCreateRequest? request = BuildRequest(
systemPromptsMessages,
channelContext,
allowToolCalls: true,
newUserMessage);

ChatCompletionCreateResponse response = await _chatGpt.ChatCompletionCreate(request);
ChatCompletionCreateResponse? response = await _chatGpt.ChatCompletionCreate(request);

ChatCompletionChoice chatGptResponse = response.Choices[0];
ChatCompletionMessage chatGptResponseMessage = chatGptResponse.Message;
ChatCompletionChoice? chatGptResponse = response.Choices[0];
ChatCompletionMessage? chatGptResponseMessage = chatGptResponse.Message;

newContextMessages.Add(chatGptResponseMessage);

if (chatGptResponse.FinishReason == ChoiceFinishReason.tool_calls)
{
List<ChatCompletionMessage> toolResponseMessages = await HandleToolCallResponse(
List<ChatCompletionMessage>? toolResponseMessages = await HandleToolCallResponse(
message,
systemPromptsMessages,
channelContext,
Expand Down Expand Up @@ -156,6 +156,9 @@ private static ChatCompletionCreateRequest BuildRequest(
? [DalleToolBuilder.BuildDalleImageToolDefinition()]
: null;

// Not allowed to include parallel tool calls field when tools is null
bool? parallelToolCalls = allowToolCalls ? false : null;

// Collect request messages
var requestMessages = new List<ChatCompletionMessage>();
requestMessages.AddRange(systemPromptsMessages);
Expand All @@ -170,7 +173,7 @@ private static ChatCompletionCreateRequest BuildRequest(
Temperature = DefaultTemperature,
MaxTokens = MaxTokensToGenerate,
Tools = chatCompletionTools,
ParallelToolCalls = false,
ParallelToolCalls = parallelToolCalls,
};

return request;
Expand All @@ -184,13 +187,13 @@ private static async Task SendImageReply(
{
const int maxFilenameLength = 32;

string truncatedFilename = filename.Length > maxFilenameLength
string? truncatedFilename = filename.Length > maxFilenameLength
? filename[..maxFilenameLength]
: filename;

string safeFilename = truncatedFilename.ToSafeFilename(imageStream.FileExtension);
string? safeFilename = truncatedFilename.ToSafeFilename(imageStream.FileExtension);

DiscordMessageBuilder mb = new DiscordMessageBuilder()
DiscordMessageBuilder? mb = new DiscordMessageBuilder()
.AddFile(safeFilename, imageStream.MemoryStream)
.WithContent(responseMessageText);

Expand All @@ -199,11 +202,11 @@ private static async Task SendImageReply(

private static async Task SendTextReply(string responseMessage, DiscordMessage messageToReplyTo)
{
List<string> messageChunks = responseMessage.SplitString(DiscordConstants.MaxMessageLength, MessageSplitToken);
List<string>? messageChunks = responseMessage.SplitString(DiscordConstants.MaxMessageLength, MessageSplitToken);

DiscordMessage nextMessageToReplyTo = messageToReplyTo;
DiscordMessage? nextMessageToReplyTo = messageToReplyTo;

foreach (string messageChunk in messageChunks)
foreach (string? messageChunk in messageChunks)
{
nextMessageToReplyTo = await nextMessageToReplyTo.RespondAsync(messageChunk);
}
Expand All @@ -214,11 +217,11 @@ private static async Task SendToolUseReply(
ChatCompletionMessage chatGptToolCallResponse,
string prompt)
{
string toolUseText = string.Format(MessageToolUseTemplate, prompt);
string responseMessageText = chatGptToolCallResponse.Content ?? string.Empty;
string? toolUseText = string.Format(MessageToolUseTemplate, prompt);
string? responseMessageText = chatGptToolCallResponse.Content ?? string.Empty;

// Tool call messages have content only sometimes
string responseTextWithToolUse = !string.IsNullOrWhiteSpace(responseMessageText)
string? responseTextWithToolUse = !string.IsNullOrWhiteSpace(responseMessageText)
? $"{responseMessageText.TrimEnd()}\n\n{toolUseText}"
: toolUseText;

Expand All @@ -234,14 +237,14 @@ private async Task<ImageStreamResult> GetDalleResult(string prompt, string userI
User = userId,
};

CreateImageResponse response = await _dalleHttpClient.CreateImage(imageRequest);
CreateImageResponse? response = await _dalleHttpClient.CreateImage(imageRequest);
if (response.Data == null || !response.Data.Any()) throw new Exception("Empty result");

ImageResponseUrlData imageUrl = response.Data.First();
ImageResponseUrlData? imageUrl = response.Data.First();

Image image = await _imageService.DownloadImage(imageUrl.Url);
Image? image = await _imageService.DownloadImage(imageUrl.Url);

ImageStreamResult imageStream = await _imageService.GetImageStream(image);
ImageStreamResult? imageStream = await _imageService.GetImageStream(image);

return imageStream;
}
Expand All @@ -254,50 +257,48 @@ private async Task<List<ChatCompletionMessage>> HandleToolCallResponse(
{
List<ChatCompletionMessage> responseMessages = [];

ChatCompletionToolCall toolCall =
ChatCompletionToolCall? toolCall =
chatGptToolCallResponse.ToolCalls?[0] ?? throw new Exception("Tool call is null.");

if (chatGptToolCallResponse.ToolCalls.Count > 1)
{
throw new Exception($"Too many tool calls: {chatGptToolCallResponse.ToolCalls.Count.ToString()}");
}

// Parse the function call arguments
string functionCallArguments = toolCall.FunctionCall.Arguments;
string? functionCallArguments = toolCall.FunctionCall.Arguments;

JsonNode parsedArguments = JsonNode.Parse(functionCallArguments)
?? throw new Exception("Could not parse function call arguments.");
JsonNode? parsedArguments = JsonNode.Parse(functionCallArguments)
?? throw new Exception("Could not parse function call arguments.");

string prompt = parsedArguments["prompt"]?.GetValue<string>()
?? throw new Exception("Function call arguments are invalid.");
string? prompt = parsedArguments["prompt"]?.GetValue<string>()
?? throw new Exception("Function call arguments are invalid.");

// Send the tool use message as a confirmation
await SendToolUseReply(message, chatGptToolCallResponse, prompt);

var authorId = message.Author!.Id.ToString();

ImageStreamResult dalleResult = await GetDalleResult(prompt, authorId);
ImageStreamResult? dalleResult = await GetDalleResult(prompt, authorId);

// Build the function call result message
var functionCallResult = $"An image of {prompt} has been generated and attached to this message.";

ChatCompletionMessage functionCallResultMessage =
ChatCompletionMessage? functionCallResultMessage =
ChatCompletionMessage.AsToolCallResult(functionCallResult, toolCall.Id);

// Force a content value for the ChatGPT response due the api not allowing nulls even though it says it does
chatGptToolCallResponse.Content ??= "null";

ChatCompletionCreateRequest request = BuildRequest(
ChatCompletionCreateRequest? request = BuildRequest(
systemPromptsMessages,
channelContext,
allowToolCalls: false,
chatGptToolCallResponse,
functionCallResultMessage);

ChatCompletionCreateResponse response = await _chatGpt.ChatCompletionCreate(request);
ChatCompletionCreateResponse? response = await _chatGpt.ChatCompletionCreate(request);

// Handle new response
ChatCompletionMessage functionCallResponse = response.Choices[0].Message;
ChatCompletionMessage? functionCallResponse = response.Choices[0].Message;

await SendImageReply(functionCallResponse.Content!, message, prompt, dalleResult);

Expand All @@ -320,7 +321,7 @@ private KattGptChannelContext GetOrCreateCachedContext(
List<ChatCompletionMessage> systemPromptsMessages,
ChatCompletionFunction chatCompletionFunction)
{
string cacheKey = KattGptChannelCache.KattGptChannelCacheKey(channel.Id);
string? cacheKey = KattGptChannelCache.KattGptChannelCacheKey(channel.Id);

KattGptChannelContext? channelContext = _cache.GetCache(cacheKey);

Expand Down Expand Up @@ -357,14 +358,14 @@ private bool ShouldHandleMessage(DiscordMessage message)
if (!IsRelevantMessage(message)) return false;

string[] commandPrefixes = [_botOptions.CommandPrefix, _botOptions.AlternateCommandPrefix];
string messageContent = message.Content.ToLower().TrimStart();
string? messageContent = message.Content.ToLower().TrimStart();

bool messageStartsWithCommandPrefix = commandPrefixes.Any(messageContent.StartsWith);

if (messageStartsWithCommandPrefix)
return false;

DiscordChannel channel = message.Channel!;
DiscordChannel? channel = message.Channel!;

ChannelOptions? channelOptions = _kattGptService.GetChannelOptions(channel);

Expand All @@ -374,7 +375,7 @@ private bool ShouldHandleMessage(DiscordMessage message)
if (!channelOptions.AlwaysOn) return true;

// otherwise check if the message does not start with the MetaMessagePrefix
string[] metaMessagePrefixes = _kattGptOptions.AlwaysOnIgnoreMessagePrefixes;
string[]? metaMessagePrefixes = _kattGptOptions.AlwaysOnIgnoreMessagePrefixes;
bool messageStartsWithMetaMessagePrefix = metaMessagePrefixes.Any(messageContent.StartsWith);

// if it does, return false
Expand All @@ -388,7 +389,7 @@ private bool ShouldHandleMessage(DiscordMessage message)
/// <returns>True if KattGpt should reply.</returns>
private bool ShouldReplyToMessage(DiscordMessage message)
{
DiscordChannel channel = message.Channel!;
DiscordChannel? channel = message.Channel!;

ChannelOptions? channelOptions = _kattGptService.GetChannelOptions(channel);

Expand All @@ -409,7 +410,7 @@ private bool ShouldReplyToMessage(DiscordMessage message)
}

// otherwise check if the message does not start with the MetaMessagePrefix
string[] metaMessagePrefixes = _kattGptOptions.AlwaysOnIgnoreMessagePrefixes;
string[]? metaMessagePrefixes = _kattGptOptions.AlwaysOnIgnoreMessagePrefixes;
bool messageStartsWithMetaMessagePrefix = metaMessagePrefixes.Any(message.Content.TrimStart().StartsWith);

// if it does, return false
Expand Down

0 comments on commit 1ecdc2e

Please sign in to comment.