diff --git a/Runtime/LLMChatTemplates.cs b/Runtime/LLMChatTemplates.cs index 93fa3e6e..c4313429 100644 --- a/Runtime/LLMChatTemplates.cs +++ b/Runtime/LLMChatTemplates.cs @@ -7,9 +7,6 @@ namespace LLMUnity { public abstract class ChatTemplate { - public string playerName; - public string AIName; - public static string DefaultTemplate; public static Type[] templateClasses; public static Dictionary templates; @@ -39,7 +36,7 @@ static ChatTemplate() chatTemplates = new Dictionary(); foreach (Type templateClass in templateClasses) { - ChatTemplate template = (ChatTemplate)Activator.CreateInstance(templateClass, "", ""); + ChatTemplate template = (ChatTemplate)Activator.CreateInstance(templateClass); if (templates.ContainsKey(template.GetName())) Debug.LogError($"{template.GetName()} already in templates"); templates[template.GetName()] = templateClass; if (templatesDescription.ContainsKey(template.GetDescription())) Debug.LogError($"{template.GetDescription()} already in templatesDescription"); @@ -57,12 +54,6 @@ static ChatTemplate() } } - public ChatTemplate(string playerName = "user", string AIName = "assistant") - { - this.playerName = playerName; - this.AIName = AIName; - } - public static string FromName(string name) { string nameLower = name.ToLower(); @@ -108,28 +99,28 @@ public static string FromGGUF(string path) return DefaultTemplate; } - public static ChatTemplate GetTemplate(string template, string playerName, string AIName) + public static ChatTemplate GetTemplate(string template) { - return (ChatTemplate)Activator.CreateInstance(templates[template], playerName, AIName); + return (ChatTemplate)Activator.CreateInstance(templates[template]); } public abstract string GetName(); public abstract string GetDescription(); public virtual string[] GetNameMatches() { return new string[] {}; } public virtual string[] GetChatTemplateMatches() { return new string[] {}; } - public abstract string[] GetStop(); + public abstract string[] GetStop(string playerName, string AIName); protected virtual string PromptPrefix() { return ""; } protected virtual string SystemPrefix() { return ""; } protected virtual string SystemSuffix() { return ""; } - protected virtual string PlayerPrefix() { return ""; } - protected virtual string AIPrefix() { return ""; } + protected virtual string PlayerPrefix(string playerName) { return ""; } + protected virtual string AIPrefix(string AIName) { return ""; } protected virtual string PrefixMessageSeparator() { return ""; } protected virtual string RequestPrefix() { return ""; } protected virtual string RequestSuffix() { return ""; } protected virtual string PairSuffix() { return ""; } - public virtual string ComputePrompt(List messages) + public virtual string ComputePrompt(List messages, string AIName) { string chatPrompt = PromptPrefix(); string systemPrompt = ""; @@ -143,21 +134,13 @@ public virtual string ComputePrompt(List messages) { chatPrompt += RequestPrefix(); if (i == 1 && systemPrompt != "") chatPrompt += systemPrompt; - if (messages[i].role != playerName) - { - Debug.Log($"Role was {messages[i].role}, was expecting {playerName}"); - } - chatPrompt += PlayerPrefix() + PrefixMessageSeparator() + messages[i].content + RequestSuffix(); + chatPrompt += PlayerPrefix(messages[i].role) + PrefixMessageSeparator() + messages[i].content + RequestSuffix(); if (i < messages.Count - 1) { - if (messages[i + 1].role != AIName) - { - Debug.Log($"Role was {messages[i + 1].role}, was expecting {AIName}"); - } - chatPrompt += AIPrefix() + PrefixMessageSeparator() + messages[i + 1].content + PairSuffix(); + chatPrompt += AIPrefix(messages[i + 1].role) + PrefixMessageSeparator() + messages[i + 1].content + PairSuffix(); } } - chatPrompt += AIPrefix(); + chatPrompt += AIPrefix(AIName); return chatPrompt; } @@ -175,8 +158,6 @@ public string[] AddStopNewlines(string[] stop) public class ChatMLTemplate : ChatTemplate { - public ChatMLTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {} - public override string GetName() { return "chatml"; } public override string GetDescription() { return "chatml (best overall)"; } public override string[] GetNameMatches() { return new string[] {"chatml", "hermes"}; } @@ -184,12 +165,12 @@ public ChatMLTemplate(string playerName = "user", string AIName = "assistant") : protected override string SystemPrefix() { return "<|im_start|>system\n"; } protected override string SystemSuffix() { return "<|im_end|>\n"; } - protected override string PlayerPrefix() { return $"<|im_start|>{playerName}\n"; } - protected override string AIPrefix() { return $"<|im_start|>{AIName}\n"; } + protected override string PlayerPrefix(string playerName) { return $"<|im_start|>{playerName}\n"; } + protected override string AIPrefix(string AIName) { return $"<|im_start|>{AIName}\n"; } protected override string RequestSuffix() { return "<|im_end|>\n"; } protected override string PairSuffix() { return "<|im_end|>\n"; } - public override string[] GetStop() + public override string[] GetStop(string playerName, string AIName) { return AddStopNewlines(new string[] { "<|im_start|>", "<|im_end|>" }); } @@ -197,8 +178,6 @@ public override string[] GetStop() public class LLama2Template : ChatTemplate { - public LLama2Template(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {} - public override string GetName() { return "llama"; } public override string GetDescription() { return "llama"; } @@ -208,7 +187,7 @@ public LLama2Template(string playerName = "user", string AIName = "assistant") : protected override string RequestSuffix() { return " [/INST]"; } protected override string PairSuffix() { return " "; } - public override string[] GetStop() + public override string[] GetStop(string playerName, string AIName) { return AddStopNewlines(new string[] { "[INST]", "[/INST]" }); } @@ -216,17 +195,15 @@ public override string[] GetStop() public class LLama2ChatTemplate : LLama2Template { - public LLama2ChatTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {} - public override string GetName() { return "llama chat"; } public override string GetDescription() { return "llama (modified for chat)"; } public override string[] GetNameMatches() { return new string[] {"llama"}; } - protected override string PlayerPrefix() { return "### " + playerName + ":"; } - protected override string AIPrefix() { return "### " + AIName + ":"; } + protected override string PlayerPrefix(string playerName) { return "### " + playerName + ":"; } + protected override string AIPrefix(string AIName) { return "### " + AIName + ":"; } protected override string PrefixMessageSeparator() { return " "; } - public override string[] GetStop() + public override string[] GetStop(string playerName, string AIName) { return AddStopNewlines(new string[] { "[INST]", "[/INST]", "###" }); } @@ -234,8 +211,6 @@ public override string[] GetStop() public class MistralInstructTemplate : ChatTemplate { - public MistralInstructTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {} - public override string GetName() { return "mistral instruct"; } public override string GetDescription() { return "mistral instruct"; } @@ -246,7 +221,7 @@ public MistralInstructTemplate(string playerName = "user", string AIName = "assi protected override string RequestSuffix() { return " [/INST]"; } protected override string PairSuffix() { return ""; } - public override string[] GetStop() + public override string[] GetStop(string playerName, string AIName) { return AddStopNewlines(new string[] { "[INST]", "[/INST]" }); } @@ -254,18 +229,16 @@ public override string[] GetStop() public class MistralChatTemplate : MistralInstructTemplate { - public MistralChatTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {} - public override string GetName() { return "mistral chat"; } public override string GetDescription() { return "mistral (modified for chat)"; } public override string[] GetNameMatches() { return new string[] {"mistral"}; } public override string[] GetChatTemplateMatches() { return new string[] {"{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"}; } - protected override string PlayerPrefix() { return "### " + playerName + ":"; } - protected override string AIPrefix() { return "### " + AIName + ":"; } + protected override string PlayerPrefix(string playerName) { return "### " + playerName + ":"; } + protected override string AIPrefix(string AIName) { return "### " + AIName + ":"; } protected override string PrefixMessageSeparator() { return " "; } - public override string[] GetStop() + public override string[] GetStop(string playerName, string AIName) { return AddStopNewlines(new string[] { "[INST]", "[/INST]", "###" }); } @@ -273,20 +246,18 @@ public override string[] GetStop() public class AlpacaTemplate : ChatTemplate { - public AlpacaTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {} - public override string GetName() { return "alpaca"; } public override string GetDescription() { return "alpaca (best alternative)"; } public override string[] GetNameMatches() { return new string[] {"alpaca"}; } protected override string SystemSuffix() { return "\n\n"; } protected override string RequestSuffix() { return "\n"; } - protected override string PlayerPrefix() { return "### " + playerName + ":"; } - protected override string AIPrefix() { return "### " + AIName + ":"; } + protected override string PlayerPrefix(string playerName) { return "### " + playerName + ":"; } + protected override string AIPrefix(string AIName) { return "### " + AIName + ":"; } protected override string PrefixMessageSeparator() { return " "; } protected override string PairSuffix() { return "\n"; } - public override string[] GetStop() + public override string[] GetStop(string playerName, string AIName) { return AddStopNewlines(new string[] { "###" }); } @@ -294,20 +265,18 @@ public override string[] GetStop() public class Phi2Template : ChatTemplate { - public Phi2Template(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {} - public override string GetName() { return "phi"; } public override string GetDescription() { return "phi"; } public override string[] GetNameMatches() { return new string[] {"phi"}; } protected override string SystemSuffix() { return "\n\n"; } protected override string RequestSuffix() { return "\n"; } - protected override string PlayerPrefix() { return playerName + ":"; } - protected override string AIPrefix() { return AIName + ":"; } + protected override string PlayerPrefix(string playerName) { return playerName + ":"; } + protected override string AIPrefix(string AIName) { return AIName + ":"; } protected override string PrefixMessageSeparator() { return " "; } protected override string PairSuffix() { return "\n"; } - public override string[] GetStop() + public override string[] GetStop(string playerName, string AIName) { return AddStopNewlines(new string[] { playerName + ":", AIName + ":" }); } @@ -315,8 +284,6 @@ public override string[] GetStop() public class ZephyrTemplate : ChatTemplate { - public ZephyrTemplate(string playerName = "user", string AIName = "assistant") : base(playerName, AIName) {} - public override string GetName() { return "zephyr"; } public override string GetDescription() { return "zephyr"; } public override string[] GetNameMatches() { return new string[] {"zephyr"}; } @@ -324,12 +291,12 @@ public ZephyrTemplate(string playerName = "user", string AIName = "assistant") : protected override string SystemPrefix() { return "<|system|>\n"; } protected override string SystemSuffix() { return "\n"; } - protected override string PlayerPrefix() { return $"<|user|>\n"; } - protected override string AIPrefix() { return $"<|assistant|>\n"; } + protected override string PlayerPrefix(string playerName) { return $"<|user|>\n"; } + protected override string AIPrefix(string AIName) { return $"<|assistant|>\n"; } protected override string RequestSuffix() { return "\n"; } protected override string PairSuffix() { return "\n"; } - public override string[] GetStop() + public override string[] GetStop(string playerName, string AIName) { return AddStopNewlines(new string[] { $"<|user|>", $"<|assistant|>" }); } diff --git a/Runtime/LLMClient.cs b/Runtime/LLMClient.cs index a04c1f60..b460a705 100644 --- a/Runtime/LLMClient.cs +++ b/Runtime/LLMClient.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using System.IO; using System.Threading.Tasks; @@ -88,7 +89,6 @@ public class LLMClient : MonoBehaviour private List<(string, string)> requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") }; private string previousEndpoint; public bool setNKeepToPrompt = true; - private List stopAll; private List WIPRequests = new List(); static object chatPromptLock = new object(); static object chatAddLock = new object(); @@ -117,7 +117,7 @@ public LLM GetServer() public virtual void SetTemplate(string templateName) { chatTemplate = templateName; - template = ChatTemplate.GetTemplate(templateName, playerName, AIName); + LoadTemplate(); } private void Reset() @@ -197,10 +197,7 @@ private void SetNKeep(List tokens) private void LoadTemplate() { - template = ChatTemplate.GetTemplate(chatTemplate, playerName, AIName); - stopAll = new List(); - stopAll.AddRange(template.GetStop()); - if (stop != null) stopAll.AddRange(stop); + template = ChatTemplate.GetTemplate(chatTemplate); } #if UNITY_EDITOR @@ -210,6 +207,13 @@ public async void SetGrammar(string path) } #endif + List GetStopwords() + { + List stopAll = new List(template.GetStop(playerName, AIName)); + if (stop != null) stopAll.AddRange(stop); + return stopAll; + } + public ChatRequest GenerateRequest(string prompt) { // setup the request struct @@ -222,7 +226,7 @@ public ChatRequest GenerateRequest(string prompt) chatRequest.n_predict = numPredict; chatRequest.n_keep = nKeep; chatRequest.stream = stream; - chatRequest.stop = stopAll; + chatRequest.stop = GetStopwords(); chatRequest.tfs_z = tfsZ; chatRequest.typical_p = typicalP; chatRequest.repeat_penalty = repeatPenalty; @@ -298,7 +302,7 @@ public async Task Chat(string question, Callback callback = null string json; lock (chatPromptLock) { AddPlayerMessage(question); - string prompt = template.ComputePrompt(chat); + string prompt = template.ComputePrompt(chat, AIName); json = JsonUtility.ToJson(GenerateRequest(prompt)); chat.RemoveAt(chat.Count - 1); } diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index 33b14aeb..535f91ae 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -89,7 +89,13 @@ public async Task RunTests() TestInitParameters((await llm.Tokenize(prompt)).Count, 1); TestWarmup(); await llm.Chat("How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.", TestChat); - TestPostChat(); + TestPostChat(3); + llm.SetPrompt(llm.prompt); + llm.AIName = "False response"; + await llm.Chat("How can I increase my meme production/output? Currently, I only create them in ancient babylonian which is time consuming.", TestChat2); + TestPostChat(3); + await llm.Chat("bye!"); + TestPostChat(5); prompt = "How are you?"; llm.SetPrompt(prompt); await llm.Chat("hi"); @@ -122,7 +128,7 @@ public void TestAlive() public void TestInitParameters(int nkeep, int chats) { Assert.That(llm.nKeep == nkeep); - Assert.That(llm.template.GetStop().Length > 0); + Assert.That(llm.template.GetStop(llm.playerName, llm.AIName).Length > 0); Assert.That(llm.GetChat().Count == chats); } @@ -142,9 +148,15 @@ public void TestChat(string reply) Assert.That(reply.Trim() == AIReply); } - public void TestPostChat() + public void TestChat2(string reply) { - Assert.That(llm.GetChat().Count == 3); + string AIReply = "One possible solution is to use a more advanced natural language processing library like NLTK or sp"; + Assert.That(reply.Trim() == AIReply); + } + + public void TestPostChat(int num) + { + Assert.That(llm.GetChat().Count == num); } } } diff --git a/Tests/Runtime/TestLLMChatTemplates.cs b/Tests/Runtime/TestLLMChatTemplates.cs index 8fa14d53..dc2f360e 100644 --- a/Tests/Runtime/TestLLMChatTemplates.cs +++ b/Tests/Runtime/TestLLMChatTemplates.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using NUnit.Framework.Internal; using NUnit.Framework; -using UnityEngine; namespace LLMUnityTests { @@ -22,7 +21,7 @@ public class TestChatTemplate public void TestChatML() { Assert.AreEqual( - new ChatMLTemplate().ComputePrompt(messages), + new ChatMLTemplate().ComputePrompt(messages, "assistant"), "<|im_start|>system\nyou are a bot<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n<|im_start|>assistant\nchat template is awesome<|im_end|>\n<|im_start|>user\ndo you think so?<|im_end|>\n<|im_start|>assistant\n" ); } @@ -31,7 +30,7 @@ public void TestChatML() public void TestMistralInstruct() { Assert.AreEqual( - new MistralInstructTemplate().ComputePrompt(messages), + new MistralInstructTemplate().ComputePrompt(messages, "assistant"), "[INST] you are a bot\n\nHello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]chat template is awesome[INST] do you think so? [/INST]" ); } @@ -40,7 +39,7 @@ public void TestMistralInstruct() public void TestMistralChat() { Assert.AreEqual( - new MistralChatTemplate().ComputePrompt(messages), + new MistralChatTemplate().ComputePrompt(messages, "assistant"), "[INST] you are a bot\n\n### user: Hello, how are you? [/INST]### assistant: I'm doing great. How can I help you today?[INST] ### user: I'd like to show off how chat templating works! [/INST]### assistant: chat template is awesome[INST] ### user: do you think so? [/INST]### assistant:" ); } @@ -49,7 +48,7 @@ public void TestMistralChat() public void TestLLama2() { Assert.AreEqual( - new LLama2Template().ComputePrompt(messages), + new LLama2Template().ComputePrompt(messages, "assistant"), "[INST] <>\nyou are a bot\n<> Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]chat template is awesome [INST] do you think so? [/INST]" ); } @@ -58,7 +57,7 @@ public void TestLLama2() public void TestLLama2Chat() { Assert.AreEqual( - new LLama2ChatTemplate().ComputePrompt(messages), + new LLama2ChatTemplate().ComputePrompt(messages, "assistant"), "[INST] <>\nyou are a bot\n<> ### user: Hello, how are you? [/INST]### assistant: I'm doing great. How can I help you today? [INST] ### user: I'd like to show off how chat templating works! [/INST]### assistant: chat template is awesome [INST] ### user: do you think so? [/INST]### assistant:" ); } @@ -67,7 +66,7 @@ public void TestLLama2Chat() public void TestAlpaca() { Assert.AreEqual( - new AlpacaTemplate().ComputePrompt(messages), + new AlpacaTemplate().ComputePrompt(messages, "assistant"), "you are a bot\n\n### user: Hello, how are you?\n### assistant: I'm doing great. How can I help you today?\n### user: I'd like to show off how chat templating works!\n### assistant: chat template is awesome\n### user: do you think so?\n### assistant:" ); } @@ -76,7 +75,7 @@ public void TestAlpaca() public void TestPhi2() { Assert.AreEqual( - new Phi2Template().ComputePrompt(messages), + new Phi2Template().ComputePrompt(messages, "assistant"), "you are a bot\n\nuser: Hello, how are you?\nassistant: I'm doing great. How can I help you today?\nuser: I'd like to show off how chat templating works!\nassistant: chat template is awesome\nuser: do you think so?\nassistant:" ); } @@ -85,7 +84,7 @@ public void TestPhi2() public void TestZephyr() { Assert.AreEqual( - new ZephyrTemplate().ComputePrompt(messages), + new ZephyrTemplate().ComputePrompt(messages, "assistant"), "<|system|>\nyou are a bot\n<|user|>\nHello, how are you?\n<|assistant|>\nI'm doing great. How can I help you today?\n<|user|>\nI'd like to show off how chat templating works!\n<|assistant|>\nchat template is awesome\n<|user|>\ndo you think so?\n<|assistant|>\n" ); }