Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow change of roles after starting the interaction #120

Merged
merged 4 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 30 additions & 63 deletions Runtime/LLMChatTemplates.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ namespace LLMUnity
{
public abstract class ChatTemplate
{
public string playerName;
public string AIName;

public static string DefaultTemplate;
public static Type[] templateClasses;
public static Dictionary<string, Type> templates;
Expand Down Expand Up @@ -39,7 +36,7 @@ static ChatTemplate()
chatTemplates = new Dictionary<string, string>();
foreach (Type templateClass in templateClasses)
{
ChatTemplate template = (ChatTemplate)Activator.CreateInstance(templateClass, "", "");
ChatTemplate template = (ChatTemplate)Activator.CreateInstance(templateClass);
if (templates.ContainsKey(template.GetName())) Debug.LogError($"{template.GetName()} already in templates");
templates[template.GetName()] = templateClass;
if (templatesDescription.ContainsKey(template.GetDescription())) Debug.LogError($"{template.GetDescription()} already in templatesDescription");
Expand All @@ -57,12 +54,6 @@ static ChatTemplate()
}
}

public ChatTemplate(string playerName = "user", string AIName = "assistant")
{
this.playerName = playerName;
this.AIName = AIName;
}

public static string FromName(string name)
{
string nameLower = name.ToLower();
Expand Down Expand Up @@ -108,28 +99,28 @@ public static string FromGGUF(string path)
return DefaultTemplate;
}

public static ChatTemplate GetTemplate(string template, string playerName, string AIName)
public static ChatTemplate GetTemplate(string template)
{
return (ChatTemplate)Activator.CreateInstance(templates[template], playerName, AIName);
return (ChatTemplate)Activator.CreateInstance(templates[template]);
}

public abstract string GetName();
public abstract string GetDescription();
public virtual string[] GetNameMatches() { return new string[] {}; }
public virtual string[] GetChatTemplateMatches() { return new string[] {}; }
public abstract string[] GetStop();
public abstract string[] GetStop(string playerName, string AIName);

protected virtual string PromptPrefix() { return ""; }
protected virtual string SystemPrefix() { return ""; }
protected virtual string SystemSuffix() { return ""; }
protected virtual string PlayerPrefix() { return ""; }
protected virtual string AIPrefix() { return ""; }
protected virtual string PlayerPrefix(string playerName) { return ""; }
protected virtual string AIPrefix(string AIName) { return ""; }
protected virtual string PrefixMessageSeparator() { return ""; }
protected virtual string RequestPrefix() { return ""; }
protected virtual string RequestSuffix() { return ""; }
protected virtual string PairSuffix() { return ""; }

public virtual string ComputePrompt(List<ChatMessage> messages)
public virtual string ComputePrompt(List<ChatMessage> messages, string AIName)
{
string chatPrompt = PromptPrefix();
string systemPrompt = "";
Expand All @@ -143,21 +134,13 @@ public virtual string ComputePrompt(List<ChatMessage> messages)
{
chatPrompt += RequestPrefix();
if (i == 1 && systemPrompt != "") chatPrompt += systemPrompt;
if (messages[i].role != playerName)
{
Debug.Log($"Role was {messages[i].role}, was expecting {playerName}");
}
chatPrompt += PlayerPrefix() + PrefixMessageSeparator() + messages[i].content + RequestSuffix();
chatPrompt += PlayerPrefix(messages[i].role) + PrefixMessageSeparator() + messages[i].content + RequestSuffix();
if (i < messages.Count - 1)
{
if (messages[i + 1].role != AIName)
{
Debug.Log($"Role was {messages[i + 1].role}, was expecting {AIName}");
}
chatPrompt += AIPrefix() + PrefixMessageSeparator() + messages[i + 1].content + PairSuffix();
chatPrompt += AIPrefix(messages[i + 1].role) + PrefixMessageSeparator() + messages[i + 1].content + PairSuffix();
}
}
chatPrompt += AIPrefix();
chatPrompt += AIPrefix(AIName);
return chatPrompt;
}

Expand All @@ -175,30 +158,26 @@ public string[] AddStopNewlines(string[] stop)

public class ChatMLTemplate : ChatTemplate
{
public ChatMLTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {}

public override string GetName() { return "chatml"; }
public override string GetDescription() { return "chatml (best overall)"; }
public override string[] GetNameMatches() { return new string[] {"chatml", "hermes"}; }
public override string[] GetChatTemplateMatches() { return new string[] {"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"}; }

protected override string SystemPrefix() { return "<|im_start|>system\n"; }
protected override string SystemSuffix() { return "<|im_end|>\n"; }
protected override string PlayerPrefix() { return $"<|im_start|>{playerName}\n"; }
protected override string AIPrefix() { return $"<|im_start|>{AIName}\n"; }
protected override string PlayerPrefix(string playerName) { return $"<|im_start|>{playerName}\n"; }
protected override string AIPrefix(string AIName) { return $"<|im_start|>{AIName}\n"; }
protected override string RequestSuffix() { return "<|im_end|>\n"; }
protected override string PairSuffix() { return "<|im_end|>\n"; }

public override string[] GetStop()
public override string[] GetStop(string playerName, string AIName)
{
return AddStopNewlines(new string[] { "<|im_start|>", "<|im_end|>" });
}
}

public class LLama2Template : ChatTemplate
{
public LLama2Template(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {}

public override string GetName() { return "llama"; }
public override string GetDescription() { return "llama"; }

Expand All @@ -208,34 +187,30 @@ public LLama2Template(string playerName = "user", string AIName = "assistant") :
protected override string RequestSuffix() { return " [/INST]"; }
protected override string PairSuffix() { return " </s>"; }

public override string[] GetStop()
public override string[] GetStop(string playerName, string AIName)
{
return AddStopNewlines(new string[] { "[INST]", "[/INST]" });
}
}

public class LLama2ChatTemplate : LLama2Template
{
public LLama2ChatTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {}

public override string GetName() { return "llama chat"; }
public override string GetDescription() { return "llama (modified for chat)"; }
public override string[] GetNameMatches() { return new string[] {"llama"}; }

protected override string PlayerPrefix() { return "### " + playerName + ":"; }
protected override string AIPrefix() { return "### " + AIName + ":"; }
protected override string PlayerPrefix(string playerName) { return "### " + playerName + ":"; }
protected override string AIPrefix(string AIName) { return "### " + AIName + ":"; }
protected override string PrefixMessageSeparator() { return " "; }

public override string[] GetStop()
public override string[] GetStop(string playerName, string AIName)
{
return AddStopNewlines(new string[] { "[INST]", "[/INST]", "###" });
}
}

public class MistralInstructTemplate : ChatTemplate
{
public MistralInstructTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {}

public override string GetName() { return "mistral instruct"; }
public override string GetDescription() { return "mistral instruct"; }

Expand All @@ -246,90 +221,82 @@ public MistralInstructTemplate(string playerName = "user", string AIName = "assi
protected override string RequestSuffix() { return " [/INST]"; }
protected override string PairSuffix() { return "</s>"; }

public override string[] GetStop()
public override string[] GetStop(string playerName, string AIName)
{
return AddStopNewlines(new string[] { "[INST]", "[/INST]" });
}
}

public class MistralChatTemplate : MistralInstructTemplate
{
public MistralChatTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {}

public override string GetName() { return "mistral chat"; }
public override string GetDescription() { return "mistral (modified for chat)"; }
public override string[] GetNameMatches() { return new string[] {"mistral"}; }
public override string[] GetChatTemplateMatches() { return new string[] {"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"}; }

protected override string PlayerPrefix() { return "### " + playerName + ":"; }
protected override string AIPrefix() { return "### " + AIName + ":"; }
protected override string PlayerPrefix(string playerName) { return "### " + playerName + ":"; }
protected override string AIPrefix(string AIName) { return "### " + AIName + ":"; }
protected override string PrefixMessageSeparator() { return " "; }

public override string[] GetStop()
public override string[] GetStop(string playerName, string AIName)
{
return AddStopNewlines(new string[] { "[INST]", "[/INST]", "###" });
}
}

public class AlpacaTemplate : ChatTemplate
{
public AlpacaTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {}

public override string GetName() { return "alpaca"; }
public override string GetDescription() { return "alpaca (best alternative)"; }
public override string[] GetNameMatches() { return new string[] {"alpaca"}; }

protected override string SystemSuffix() { return "\n\n"; }
protected override string RequestSuffix() { return "\n"; }
protected override string PlayerPrefix() { return "### " + playerName + ":"; }
protected override string AIPrefix() { return "### " + AIName + ":"; }
protected override string PlayerPrefix(string playerName) { return "### " + playerName + ":"; }
protected override string AIPrefix(string AIName) { return "### " + AIName + ":"; }
protected override string PrefixMessageSeparator() { return " "; }
protected override string PairSuffix() { return "\n"; }

public override string[] GetStop()
public override string[] GetStop(string playerName, string AIName)
{
return AddStopNewlines(new string[] { "###" });
}
}

public class Phi2Template : ChatTemplate
{
public Phi2Template(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {}

public override string GetName() { return "phi"; }
public override string GetDescription() { return "phi"; }
public override string[] GetNameMatches() { return new string[] {"phi"}; }

protected override string SystemSuffix() { return "\n\n"; }
protected override string RequestSuffix() { return "\n"; }
protected override string PlayerPrefix() { return playerName + ":"; }
protected override string AIPrefix() { return AIName + ":"; }
protected override string PlayerPrefix(string playerName) { return playerName + ":"; }
protected override string AIPrefix(string AIName) { return AIName + ":"; }
protected override string PrefixMessageSeparator() { return " "; }
protected override string PairSuffix() { return "\n"; }

public override string[] GetStop()
public override string[] GetStop(string playerName, string AIName)
{
return AddStopNewlines(new string[] { playerName + ":", AIName + ":" });
}
}

public class ZephyrTemplate : ChatTemplate
{
public ZephyrTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {}

public override string GetName() { return "zephyr"; }
public override string GetDescription() { return "zephyr"; }
public override string[] GetNameMatches() { return new string[] {"zephyr"}; }
public override string[] GetChatTemplateMatches() { return new string[] {"{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"}; }

protected override string SystemPrefix() { return "<|system|>\n"; }
protected override string SystemSuffix() { return "</s>\n"; }
protected override string PlayerPrefix() { return $"<|user|>\n"; }
protected override string AIPrefix() { return $"<|assistant|>\n"; }
protected override string PlayerPrefix(string playerName) { return $"<|user|>\n"; }
protected override string AIPrefix(string AIName) { return $"<|assistant|>\n"; }
protected override string RequestSuffix() { return "</s>\n"; }
protected override string PairSuffix() { return "</s>\n"; }

public override string[] GetStop()
public override string[] GetStop(string playerName, string AIName)
{
return AddStopNewlines(new string[] { $"<|user|>", $"<|assistant|>" });
}
Expand Down
20 changes: 12 additions & 8 deletions Runtime/LLMClient.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;
Expand Down Expand Up @@ -88,7 +89,6 @@ public class LLMClient : MonoBehaviour
private List<(string, string)> requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
private string previousEndpoint;
public bool setNKeepToPrompt = true;
private List<string> stopAll;
private List<UnityWebRequest> WIPRequests = new List<UnityWebRequest>();
static object chatPromptLock = new object();
static object chatAddLock = new object();
Expand Down Expand Up @@ -117,7 +117,7 @@ public LLM GetServer()
public virtual void SetTemplate(string templateName)
{
chatTemplate = templateName;
template = ChatTemplate.GetTemplate(templateName, playerName, AIName);
LoadTemplate();
}

private void Reset()
Expand Down Expand Up @@ -197,10 +197,7 @@ private void SetNKeep(List<int> tokens)

private void LoadTemplate()
{
template = ChatTemplate.GetTemplate(chatTemplate, playerName, AIName);
stopAll = new List<string>();
stopAll.AddRange(template.GetStop());
if (stop != null) stopAll.AddRange(stop);
template = ChatTemplate.GetTemplate(chatTemplate);
}

#if UNITY_EDITOR
Expand All @@ -210,6 +207,13 @@ public async void SetGrammar(string path)
}

#endif
List<string> GetStopwords()
{
List<string> stopAll = new List<string>(template.GetStop(playerName, AIName));
if (stop != null) stopAll.AddRange(stop);
return stopAll;
}

public ChatRequest GenerateRequest(string prompt)
{
// setup the request struct
Expand All @@ -222,7 +226,7 @@ public ChatRequest GenerateRequest(string prompt)
chatRequest.n_predict = numPredict;
chatRequest.n_keep = nKeep;
chatRequest.stream = stream;
chatRequest.stop = stopAll;
chatRequest.stop = GetStopwords();
chatRequest.tfs_z = tfsZ;
chatRequest.typical_p = typicalP;
chatRequest.repeat_penalty = repeatPenalty;
Expand Down Expand Up @@ -298,7 +302,7 @@ public async Task<string> Chat(string question, Callback<string> callback = null
string json;
lock (chatPromptLock) {
AddPlayerMessage(question);
string prompt = template.ComputePrompt(chat);
string prompt = template.ComputePrompt(chat, AIName);
json = JsonUtility.ToJson(GenerateRequest(prompt));
chat.RemoveAt(chat.Count - 1);
}
Expand Down
20 changes: 16 additions & 4 deletions Tests/Runtime/TestLLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,13 @@ public async Task RunTests()
TestInitParameters((await llm.Tokenize(prompt)).Count, 1);
TestWarmup();
await llm.Chat("How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.", TestChat);
TestPostChat();
TestPostChat(3);
llm.SetPrompt(llm.prompt);
llm.AIName = "False response";
await llm.Chat("How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.", TestChat2);
TestPostChat(3);
await llm.Chat("bye!");
TestPostChat(5);
prompt = "How are you?";
llm.SetPrompt(prompt);
await llm.Chat("hi");
Expand Down Expand Up @@ -122,7 +128,7 @@ public void TestAlive()
public void TestInitParameters(int nkeep, int chats)
{
Assert.That(llm.nKeep == nkeep);
Assert.That(llm.template.GetStop().Length > 0);
Assert.That(llm.template.GetStop(llm.playerName, llm.AIName).Length > 0);
Assert.That(llm.GetChat().Count == chats);
}

Expand All @@ -142,9 +148,15 @@ public void TestChat(string reply)
Assert.That(reply.Trim() == AIReply);
}

public void TestPostChat()
public void TestChat2(string reply)
{
Assert.That(llm.GetChat().Count == 3);
string AIReply = "One possible solution is to use a more advanced natural language processing library like NLTK or sp";
Assert.That(reply.Trim() == AIReply);
}

public void TestPostChat(int num)
{
Assert.That(llm.GetChat().Count == num);
}
}
}
Loading
Loading