Skip to content

Commit

Permalink
Add Function CALL
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrandvd committed Dec 30, 2023
1 parent 38fa294 commit c368f2f
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 19 deletions.
4 changes: 3 additions & 1 deletion src/skUnit.Tests/ScenarioTests/ParseChatScenarioTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ Approves that eiffel tower is tall or is positive about it.
### CHECK SemanticCondition
Approves that eiffel tower is tall or is positive about it.
## CALL GetIntent
## CALL Content.GetIntent
```json
{
"input": "$input",
Expand Down Expand Up @@ -147,6 +147,8 @@ It is Json
var agentChatItem = first.ChatItems.ElementAt(1);

Assert.Equal(2, agentChatItem.Assertions.Count);
Assert.Equal("Content", agentChatItem.FunctionCalls.First().PluginName);
Assert.Equal("GetIntent", agentChatItem.FunctionCalls.First().FunctionName);
Assert.Equal(3, agentChatItem.FunctionCalls.First().Assertions.Count);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,16 @@ public class KernelChatTests : SemanticTestBase
{
public KernelChatTests(ITestOutputHelper output) : base(output)
{

var func = Kernel.CreateFunctionFromPrompt("""
[[INPUT]]
{{$input}}
[[END OF INPUT]]
Get intent of input. Intent should be one of these options: {{$options}}.
INTENT:
""", new PromptExecutionSettings(), "GetIntent");
Kernel.Plugins.AddFromFunctions("MyPlugin", "", new[] { func });
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ Yes it is
### CHECK SemanticCondition
Approves that eiffel tower is tall or is positive about it.

## CALL MyPlugin.GetIntent
```json
{
"options": "Positive,Negative,Neutral"
}
```
## CHECK ContainsAny
Neutral,Positive

## [USER]
What about everest mountain?

Expand Down
83 changes: 70 additions & 13 deletions src/skUnit/Asserts/SemanticKernelAssert_Chat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Microsoft.SemanticKernel.Connectors.OpenAI;
using skUnit.Exceptions;
using skUnit.Scenarios;
using skUnit.Scenarios.Parsers.Assertions;

namespace skUnit;

Expand Down Expand Up @@ -99,27 +100,83 @@ public async Task CheckChatScenarioAsync(Kernel kernel, ChatScenario scenario, F

foreach (var assertion in chatItem.Assertions)
{
Log($"### CHECK {assertion.AssertionType}");
Log($"{assertion.Description}");

try
await CheckAssertionAsync(assertion, answer);
}

foreach (var functionCall in chatItem.FunctionCalls)
{
Log($"## CALL {functionCall.FunctionName}");
Log(functionCall.ArgumentsText);
Log();

var function = kernel.Plugins[functionCall.PluginName][functionCall.FunctionName];

var arguments = new KernelArguments();

var history = new ChatHistory(chatHistory.Take(chatHistory.Count - 1));
var historyText = string.Join(
Environment.NewLine,
history.Select(c => $"[{c.Role}]: {c.Content}"));

var input = chatHistory.Last().Content;

arguments.Add("input", input);
arguments.Add("history", historyText);

foreach (var functionCallArgument in functionCall.Arguments)
{
await assertion.Assert(Semantic, answer);
Log($"✅ OK");
Log("");
if (functionCallArgument.LiteralValue is not null)
{
arguments.Add(functionCallArgument.Name, functionCallArgument.LiteralValue);
}
else if (functionCallArgument.InputVariable is not null)
{

}
else
{
throw new InvalidOperationException($"""
Invalid function arguments:
{ functionCallArgument}
""" );
}
}
catch (SemanticAssertException exception)

var result = await function.InvokeAsync<string>(kernel, arguments);

Log($"## CALL RESULT {functionCall.FunctionName}");
Log(result);
Log();

foreach (var assertion in functionCall.Assertions)
{
Log("❌ FAIL");
Log("Reason:");
Log(exception.Message);
Log();
throw;
await CheckAssertionAsync(assertion, result ?? "");
}
}
}
}

private async Task CheckAssertionAsync(IKernelAssertion assertion, string answer)
{
Log($"### CHECK {assertion.AssertionType}");
Log($"{assertion.Description}");

try
{
await assertion.Assert(Semantic, answer);
Log($"✅ OK");
Log("");
}
catch (SemanticAssertException exception)
{
Log("❌ FAIL");
Log("Reason:");
Log(exception.Message);
Log();
throw;
}
}

/// <summary>
/// Checks whether all of the <paramref name="scenarios"/> passes on the given <paramref name="kernel"/>
/// using its ChatCompletionService.
Expand Down
4 changes: 3 additions & 1 deletion src/skUnit/Scenarios/ChatScenario.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,17 @@ public override string ToString()

public class FunctionCall
{
public required string PluginName { get; set; }
public required string FunctionName { get; set; }
public List<FunctionCallArgument> Arguments { get; set; } = new();
public string? ArgumentsText { get; set; }

/// <summary>
/// All the assertions that should be checked after the result of InvokeAsync is ready.
/// </summary>
public List<IKernelAssertion> Assertions { get; set; } = new();

public override string ToString() => $"FunctionName({string.Join(",", Arguments.Select(a => a.Name))})";
public override string ToString() => $"{PluginName}{FunctionName}({string.Join(",", Arguments.Select(a => a.Name))})";
}

public class FunctionCallArgument
Expand Down
20 changes: 17 additions & 3 deletions src/skUnit/Scenarios/Parsers/ChatScenarioParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,25 @@ private static bool PackBlock(ChatScenario scenario, string newBlock, ref string
//if (!match.Success)
// throw new InvalidOperationException($"Call is not valid: {key}");

var function = key;
var functionText = key;

if (string.IsNullOrWhiteSpace(function))
if (string.IsNullOrWhiteSpace(functionText))
{
throw new InvalidOperationException($"CALL function name is null");
}

var callParts = functionText.Split('.');
if (callParts.Length != 2)
{
throw new InvalidOperationException($"""
Invalid function call. It should be in Plugin.Function format:
{ functionText}
""" );
}

var plugin = callParts[0];
var function = callParts[1];

var arguments = new List<FunctionCallArgument>();

if (!string.IsNullOrWhiteSpace(contentText))
Expand Down Expand Up @@ -179,8 +191,10 @@ private static bool PackBlock(ChatScenario scenario, string newBlock, ref string

scenario.ChatItems.Last().FunctionCalls.Add(new FunctionCall()
{
PluginName = plugin,
FunctionName = function,
Arguments = arguments
Arguments = arguments,
ArgumentsText = contentText
});
}
else if (currentBlock == "CHECK")
Expand Down

0 comments on commit c368f2f

Please sign in to comment.