diff --git a/CHANGELOG.md b/CHANGELOG.md index 657b9912..a0ca5430 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +## v1.2.2 +#### 🐛 Fixes + +- use namespaces in all classes (PR: #104) +- await separately in StartServer (PR: #107) + + ## v1.2.1 #### 🐛 Fixes diff --git a/CHANGELOG.release.md b/CHANGELOG.release.md index 8c34aca6..75324372 100644 --- a/CHANGELOG.release.md +++ b/CHANGELOG.release.md @@ -1,5 +1,5 @@ ### 🐛 Fixes -- Kill server after Unity crash (PR: #101) -- Persist chat template on remote servers (PR: #103) +- use namespaces in all classes (PR: #104) +- await separately in StartServer (PR: #107) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index fc1e512d..e590ec97 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -301,21 +301,30 @@ private async Task StartLLMServer() string GPUArgument = numGPULayers <= 0 ? "" : $" -ngl {numGPULayers}"; LLMUnitySetup.makeExecutable(server); - await RunAndWait(server, arguments + GPUArgument); + + RunServerCommand(server, arguments + GPUArgument); + if (asynchronousStartup) await WaitOneASync(serverBlock, TimeSpan.FromSeconds(60)); + else serverBlock.WaitOne(60000); if (process.HasExited && mmapCrash) { Debug.Log("Mmap error, fallback to no mmap use"); serverBlock.Reset(); arguments += " --no-mmap"; - await RunAndWait(server, arguments + GPUArgument); + + RunServerCommand(server, arguments + GPUArgument); + if (asynchronousStartup) await WaitOneASync(serverBlock, TimeSpan.FromSeconds(60)); + else serverBlock.WaitOne(60000); } if (process.HasExited && numGPULayers > 0) { Debug.Log("GPU failed, fallback to CPU"); serverBlock.Reset(); - await RunAndWait(server, arguments); + + RunServerCommand(server, arguments); + if (asynchronousStartup) await WaitOneASync(serverBlock, TimeSpan.FromSeconds(60)); + else serverBlock.WaitOne(60000); } if (process.HasExited) throw new Exception("Server could not be started!"); @@ -342,13 +351,6 @@ public void OnDestroy() StopProcess(); } - private async Task RunAndWait(string exe, string args, int seconds = 60) - { - RunServerCommand(exe, args); - if (asynchronousStartup) await WaitOneASync(serverBlock, TimeSpan.FromSeconds(seconds)); - else serverBlock.WaitOne(seconds * 1000); - } - /// Wrapper from https://stackoverflow.com/a/18766131 private static Task WaitOneASync(WaitHandle handle, TimeSpan timeout) { diff --git a/Runtime/LLMGGUF.cs b/Runtime/LLMGGUF.cs index 213a0843..8171e817 100644 --- a/Runtime/LLMGGUF.cs +++ b/Runtime/LLMGGUF.cs @@ -3,221 +3,224 @@ using System.IO; using System.Runtime.InteropServices; -public enum GGUFValueType +namespace LLMUnity { - UINT8 = 0, - INT8 = 1, - UINT16 = 2, - INT16 = 3, - UINT32 = 4, - INT32 = 5, - FLOAT32 = 6, - BOOL = 7, - STRING = 8, - ARRAY = 9, - UINT64 = 10, - INT64 = 11, - FLOAT64 = 12 -} - -public class ReaderField -{ - public int offset; - public string name; - public List parts = new List(); - public List data = new List(); - public List types = new List(); -} - -public class ReaderTensor -{ - public string name; - public GGUFValueType tensor_type; - public uint[] shape; - public int n_elements; - public int n_bytes; - public int data_offset; - public Array data; - public ReaderField field; -} - -public class GGUFReader -{ - private const uint GGUF_MAGIC = 0x46554747; // "GGUF" - private const int GGUF_VERSION = 3; - private readonly List READER_SUPPORTED_VERSIONS = new List { 2, GGUF_VERSION }; - private Dictionary gguf_scalar_to_np = new Dictionary - { - { GGUFValueType.UINT8, typeof(byte) }, - { GGUFValueType.INT8, typeof(sbyte) }, - { GGUFValueType.UINT16, typeof(ushort) }, - { GGUFValueType.INT16, typeof(short) }, - { GGUFValueType.UINT32, typeof(uint) }, - { GGUFValueType.INT32, typeof(int) }, - { GGUFValueType.FLOAT32, typeof(float) }, - { GGUFValueType.UINT64, typeof(ulong) }, - { GGUFValueType.INT64, typeof(long) }, - { GGUFValueType.FLOAT64, typeof(double) }, - { GGUFValueType.BOOL, typeof(bool) } - }; - - // private MemoryStream data; - private FileStream data; - - public Dictionary fields = new Dictionary(); - public List tensors = new List(); - - public GGUFReader(string path, string mode = "r") + public enum GGUFValueType { - // data = new MemoryStream(File.ReadAllBytes(path)); - data = new FileStream(path, FileMode.Open, FileAccess.Read); - int offs = 0; - - if (BitConverter.ToUInt32(ReadBytes(offs, 4), 0) != GGUF_MAGIC) - throw new ArgumentException("GGUF magic invalid"); - offs += 4; - - uint temp_version = BitConverter.ToUInt32(ReadBytes(offs, 4)); - if ((temp_version & 65535) == 0) - { - byte[] tempBytes = ReadBytes(offs, 4); - Array.Reverse(tempBytes); - temp_version = BitConverter.ToUInt32(tempBytes, 0); - } - uint version = temp_version; - - if (!READER_SUPPORTED_VERSIONS.Contains((int)version)) - throw new ArgumentException($"Sorry, file appears to be version {version} which we cannot handle"); - - offs += PushField(new ReaderField { offset = offs, name = "GGUF.version", parts = new List { new uint[] { temp_version } }, data = new List { 0 }, types = new List { GGUFValueType.UINT32 } }); - ulong[] temp_counts = new ulong[2]; - Buffer.BlockCopy(ReadBytes(offs, 16), 0, temp_counts, 0, 16); - offs += PushField(new ReaderField { offset = offs, name = "GGUF.tensor_count", parts = new List { new ulong[] { temp_counts[0] } }, data = new List { 0 }, types = new List { GGUFValueType.UINT64 } }); - offs += PushField(new ReaderField { offset = offs, name = "GGUF.kv_count", parts = new List { new ulong[] { temp_counts[1] } }, data = new List { 0 }, types = new List { GGUFValueType.UINT64 } }); - ulong tensor_count = temp_counts[0]; - ulong kv_count = temp_counts[1]; - offs = BuildFields(offs, (int)kv_count); - data.Close(); + UINT8 = 0, + INT8 = 1, + UINT16 = 2, + INT16 = 3, + UINT32 = 4, + INT32 = 5, + FLOAT32 = 6, + BOOL = 7, + STRING = 8, + ARRAY = 9, + UINT64 = 10, + INT64 = 11, + FLOAT64 = 12 } - public ReaderField GetField(string key) + public class ReaderField { - if (fields.TryGetValue(key, out ReaderField value)) - return value; - return null; + public int offset; + public string name; + public List parts = new List(); + public List data = new List(); + public List types = new List(); } - private byte[] ReadBytes(int offset, int count) + public class ReaderTensor { - byte[] buffer = new byte[count]; - data.Seek(offset, SeekOrigin.Begin); - data.Read(buffer, 0, count); - return buffer; + public string name; + public GGUFValueType tensor_type; + public uint[] shape; + public int n_elements; + public int n_bytes; + public int data_offset; + public Array data; + public ReaderField field; } - private int PushField(ReaderField field, bool skip_sum = false) + public class GGUFReader { - if (fields.ContainsKey(field.name)) - throw new ArgumentException($"Duplicate {field.name} already in list at offset {field.offset}"); - fields[field.name] = field; - if (skip_sum) - return 0; - int sum = 0; - for (int i = 0; i < field.parts.Count; i++) + private const uint GGUF_MAGIC = 0x46554747; // "GGUF" + private const int GGUF_VERSION = 3; + private readonly List READER_SUPPORTED_VERSIONS = new List { 2, GGUF_VERSION }; + private Dictionary gguf_scalar_to_np = new Dictionary + { + { GGUFValueType.UINT8, typeof(byte) }, + { GGUFValueType.INT8, typeof(sbyte) }, + { GGUFValueType.UINT16, typeof(ushort) }, + { GGUFValueType.INT16, typeof(short) }, + { GGUFValueType.UINT32, typeof(uint) }, + { GGUFValueType.INT32, typeof(int) }, + { GGUFValueType.FLOAT32, typeof(float) }, + { GGUFValueType.UINT64, typeof(ulong) }, + { GGUFValueType.INT64, typeof(long) }, + { GGUFValueType.FLOAT64, typeof(double) }, + { GGUFValueType.BOOL, typeof(bool) } + }; + + // private MemoryStream data; + private FileStream data; + + public Dictionary fields = new Dictionary(); + public List tensors = new List(); + + public GGUFReader(string path, string mode = "r") { - Type partType = gguf_scalar_to_np[field.types[i]]; - sum += Marshal.SizeOf(partType) * field.parts[i].Length; + // data = new MemoryStream(File.ReadAllBytes(path)); + data = new FileStream(path, FileMode.Open, FileAccess.Read); + int offs = 0; + + if (BitConverter.ToUInt32(ReadBytes(offs, 4), 0) != GGUF_MAGIC) + throw new ArgumentException("GGUF magic invalid"); + offs += 4; + + uint temp_version = BitConverter.ToUInt32(ReadBytes(offs, 4)); + if ((temp_version & 65535) == 0) + { + byte[] tempBytes = ReadBytes(offs, 4); + Array.Reverse(tempBytes); + temp_version = BitConverter.ToUInt32(tempBytes, 0); + } + uint version = temp_version; + + if (!READER_SUPPORTED_VERSIONS.Contains((int)version)) + throw new ArgumentException($"Sorry, file appears to be version {version} which we cannot handle"); + + offs += PushField(new ReaderField { offset = offs, name = "GGUF.version", parts = new List { new uint[] { temp_version } }, data = new List { 0 }, types = new List { GGUFValueType.UINT32 } }); + ulong[] temp_counts = new ulong[2]; + Buffer.BlockCopy(ReadBytes(offs, 16), 0, temp_counts, 0, 16); + offs += PushField(new ReaderField { offset = offs, name = "GGUF.tensor_count", parts = new List { new ulong[] { temp_counts[0] } }, data = new List { 0 }, types = new List { GGUFValueType.UINT64 } }); + offs += PushField(new ReaderField { offset = offs, name = "GGUF.kv_count", parts = new List { new ulong[] { temp_counts[1] } }, data = new List { 0 }, types = new List { GGUFValueType.UINT64 } }); + ulong tensor_count = temp_counts[0]; + ulong kv_count = temp_counts[1]; + offs = BuildFields(offs, (int)kv_count); + data.Close(); } - return sum; - } - private (ulong[], byte[]) GetStr(int offset) - { - ulong slen = BitConverter.ToUInt64(ReadBytes(offset, 8)); - byte[] sdata = ReadBytes(offset + 8, (int)slen); - return (new ulong[] { slen }, sdata); - } + public ReaderField GetField(string key) + { + if (fields.TryGetValue(key, out ReaderField value)) + return value; + return null; + } - private (int, List, List, List) GetFieldParts(int orig_offs, int raw_type) - { - int offs = orig_offs; - List types = new List(); - types.Add((GGUFValueType)raw_type); - // Handle strings. - if ((GGUFValueType)raw_type == GGUFValueType.STRING) + private byte[] ReadBytes(int offset, int count) { - (ulong[] slen, byte[] sdata) = GetStr(offs); - List sparts = new List { slen, sdata }; - int size = slen.Length * sizeof(ulong) + sdata.Length; - return (size, sparts, new List { 1 }, types); + byte[] buffer = new byte[count]; + data.Seek(offset, SeekOrigin.Begin); + data.Read(buffer, 0, count); + return buffer; } - // Check if it's a simple scalar type. - if (gguf_scalar_to_np.TryGetValue((GGUFValueType)raw_type, out Type nptype)) + private int PushField(ReaderField field, bool skip_sum = false) { - Array val = ReadBytes(offs, Marshal.SizeOf(nptype)); - int size = nptype == typeof(bool) ? 1 : Marshal.SizeOf(nptype); - return (size, new List { val }, new List { 0 }, types); + if (fields.ContainsKey(field.name)) + throw new ArgumentException($"Duplicate {field.name} already in list at offset {field.offset}"); + fields[field.name] = field; + if (skip_sum) + return 0; + int sum = 0; + for (int i = 0; i < field.parts.Count; i++) + { + Type partType = gguf_scalar_to_np[field.types[i]]; + sum += Marshal.SizeOf(partType) * field.parts[i].Length; + } + return sum; } - // Handle arrays. - if ((GGUFValueType)raw_type == GGUFValueType.ARRAY) + private (ulong[], byte[]) GetStr(int offset) { - int raw_itype = BitConverter.ToInt32(ReadBytes(offs, 4)); - offs += Marshal.SizeOf(typeof(int)); + ulong slen = BitConverter.ToUInt64(ReadBytes(offset, 8)); + byte[] sdata = ReadBytes(offset + 8, (int)slen); + return (new ulong[] { slen }, sdata); + } - ulong alen = BitConverter.ToUInt64(ReadBytes(offs, 8)); - offs += Marshal.SizeOf(typeof(ulong)); + private (int, List, List, List) GetFieldParts(int orig_offs, int raw_type) + { + int offs = orig_offs; + List types = new List(); + types.Add((GGUFValueType)raw_type); + // Handle strings. + if ((GGUFValueType)raw_type == GGUFValueType.STRING) + { + (ulong[] slen, byte[] sdata) = GetStr(offs); + List sparts = new List { slen, sdata }; + int size = slen.Length * sizeof(ulong) + sdata.Length; + return (size, sparts, new List { 1 }, types); + } - List aparts = new List { BitConverter.GetBytes(raw_itype), BitConverter.GetBytes(alen) }; - List data_idxs = new List(); + // Check if it's a simple scalar type. + if (gguf_scalar_to_np.TryGetValue((GGUFValueType)raw_type, out Type nptype)) + { + Array val = ReadBytes(offs, Marshal.SizeOf(nptype)); + int size = nptype == typeof(bool) ? 1 : Marshal.SizeOf(nptype); + return (size, new List { val }, new List { 0 }, types); + } - for (int idx = 0; idx < (int)alen; idx++) + // Handle arrays. + if ((GGUFValueType)raw_type == GGUFValueType.ARRAY) { - (int curr_size, List curr_parts, List curr_idxs, List curr_types) = GetFieldParts(offs, raw_itype); - if (idx == 0) - types.AddRange(curr_types); - - int idxs_offs = aparts.Count; - aparts.AddRange(curr_parts); - data_idxs.AddRange(new List(curr_idxs.ConvertAll(i => i + idxs_offs))); - offs += curr_size; + int raw_itype = BitConverter.ToInt32(ReadBytes(offs, 4)); + offs += Marshal.SizeOf(typeof(int)); + + ulong alen = BitConverter.ToUInt64(ReadBytes(offs, 8)); + offs += Marshal.SizeOf(typeof(ulong)); + + List aparts = new List { BitConverter.GetBytes(raw_itype), BitConverter.GetBytes(alen) }; + List data_idxs = new List(); + + for (int idx = 0; idx < (int)alen; idx++) + { + (int curr_size, List curr_parts, List curr_idxs, List curr_types) = GetFieldParts(offs, raw_itype); + if (idx == 0) + types.AddRange(curr_types); + + int idxs_offs = aparts.Count; + aparts.AddRange(curr_parts); + data_idxs.AddRange(new List(curr_idxs.ConvertAll(i => i + idxs_offs))); + offs += curr_size; + } + return (offs - orig_offs, aparts, data_idxs, types); } - return (offs - orig_offs, aparts, data_idxs, types); + // We can't deal with this one. + throw new ArgumentException($"Unknown/unhandled field type {(GGUFValueType)raw_type}"); } - // We can't deal with this one. - throw new ArgumentException($"Unknown/unhandled field type {(GGUFValueType)raw_type}"); - } - private int BuildFields(int offs, int count) - { - for (int i = 0; i < count; i++) + private int BuildFields(int offs, int count) { - int orig_offs = offs; - (ulong[] kv_klen, byte[] kv_kdata) = GetStr(offs); - offs += Marshal.SizeOf(typeof(ulong)) + kv_kdata.Length; - - int raw_kv_type = BitConverter.ToInt32(ReadBytes(offs, 4)); - offs += Marshal.SizeOf(typeof(int)); - List parts = new List { kv_klen, kv_kdata, BitConverter.GetBytes(raw_kv_type) }; - List idxs_offs = new List { parts.Count }; - - (int field_size, List field_parts, List field_idxs, List field_types) = GetFieldParts(offs, raw_kv_type); - if (field_size == -1) - continue; - - parts.AddRange(field_parts); - ReaderField readerField = new ReaderField + for (int i = 0; i < count; i++) { - offset = orig_offs, - name = System.Text.Encoding.UTF8.GetString(kv_kdata), - parts = parts, - data = new List(field_idxs.ConvertAll(idx => idx + idxs_offs[0])), - types = field_types - }; - PushField(readerField, skip_sum: true); - offs += field_size; + int orig_offs = offs; + (ulong[] kv_klen, byte[] kv_kdata) = GetStr(offs); + offs += Marshal.SizeOf(typeof(ulong)) + kv_kdata.Length; + + int raw_kv_type = BitConverter.ToInt32(ReadBytes(offs, 4)); + offs += Marshal.SizeOf(typeof(int)); + List parts = new List { kv_klen, kv_kdata, BitConverter.GetBytes(raw_kv_type) }; + List idxs_offs = new List { parts.Count }; + + (int field_size, List field_parts, List field_idxs, List field_types) = GetFieldParts(offs, raw_kv_type); + if (field_size == -1) + continue; + + parts.AddRange(field_parts); + ReaderField readerField = new ReaderField + { + offset = orig_offs, + name = System.Text.Encoding.UTF8.GetString(kv_kdata), + parts = parts, + data = new List(field_idxs.ConvertAll(idx => idx + idxs_offs[0])), + types = field_types + }; + PushField(readerField, skip_sum: true); + offs += field_size; + } + return offs; } - return offs; } } diff --git a/Samples~/ServerClient/ServerClient.cs b/Samples~/ServerClient/ServerClient.cs index e2679827..e6ddcd3a 100644 --- a/Samples~/ServerClient/ServerClient.cs +++ b/Samples~/ServerClient/ServerClient.cs @@ -3,76 +3,79 @@ using UnityEngine.UI; -public class ServerClientInteraction +namespace LLMUnitySamples { - InputField playerText; - Text AIText; - LLMClient llm; - - public ServerClientInteraction(InputField playerText, Text AIText, LLMClient llm) + public class ServerClientInteraction { - this.playerText = playerText; - this.AIText = AIText; - this.llm = llm; - } + InputField playerText; + Text AIText; + LLMClient llm; - public void Start() - { - playerText.onSubmit.AddListener(onInputFieldSubmit); - playerText.Select(); - } + public ServerClientInteraction(InputField playerText, Text AIText, LLMClient llm) + { + this.playerText = playerText; + this.AIText = AIText; + this.llm = llm; + } - public void onInputFieldSubmit(string message) - { - playerText.interactable = false; - AIText.text = "..."; - _ = llm.Chat(message, SetAIText, AIReplyComplete); - } + public void Start() + { + playerText.onSubmit.AddListener(onInputFieldSubmit); + playerText.Select(); + } - public void SetAIText(string text) - { - AIText.text = text; - } + public void onInputFieldSubmit(string message) + { + playerText.interactable = false; + AIText.text = "..."; + _ = llm.Chat(message, SetAIText, AIReplyComplete); + } - public void AIReplyComplete() - { - playerText.interactable = true; - playerText.Select(); - playerText.text = ""; + public void SetAIText(string text) + { + AIText.text = text; + } + + public void AIReplyComplete() + { + playerText.interactable = true; + playerText.Select(); + playerText.text = ""; + } } -} -public class ServerClient : MonoBehaviour -{ - public LLM llm; - public InputField playerText1; - public Text AIText1; - ServerClientInteraction interaction1; + public class ServerClient : MonoBehaviour + { + public LLM llm; + public InputField playerText1; + public Text AIText1; + ServerClientInteraction interaction1; - public LLMClient llmClient; - public InputField playerText2; - public Text AIText2; - ServerClientInteraction interaction2; + public LLMClient llmClient; + public InputField playerText2; + public Text AIText2; + ServerClientInteraction interaction2; - void Start() - { - interaction1 = new ServerClientInteraction(playerText1, AIText1, llm); - interaction2 = new ServerClientInteraction(playerText2, AIText2, llmClient); - interaction1.Start(); - interaction2.Start(); - } + void Start() + { + interaction1 = new ServerClientInteraction(playerText1, AIText1, llm); + interaction2 = new ServerClientInteraction(playerText2, AIText2, llmClient); + interaction1.Start(); + interaction2.Start(); + } - public void CancelRequests() - { - llm.CancelRequests(); - llmClient.CancelRequests(); - interaction1.AIReplyComplete(); - interaction2.AIReplyComplete(); - } + public void CancelRequests() + { + llm.CancelRequests(); + llmClient.CancelRequests(); + interaction1.AIReplyComplete(); + interaction2.AIReplyComplete(); + } - public void ExitGame() - { - Debug.Log("Exit button clicked"); - Application.Quit(); + public void ExitGame() + { + Debug.Log("Exit button clicked"); + Application.Quit(); + } } } diff --git a/Samples~/SimpleInteraction/SimpleInteraction.cs b/Samples~/SimpleInteraction/SimpleInteraction.cs index 12ffc5e8..005778b8 100644 --- a/Samples~/SimpleInteraction/SimpleInteraction.cs +++ b/Samples~/SimpleInteraction/SimpleInteraction.cs @@ -2,46 +2,49 @@ using LLMUnity; using UnityEngine.UI; -public class SimpleInteraction : MonoBehaviour +namespace LLMUnitySamples { - public LLM llm; - public InputField playerText; - public Text AIText; - - void Start() + public class SimpleInteraction : MonoBehaviour { - playerText.onSubmit.AddListener(onInputFieldSubmit); - playerText.Select(); - } + public LLM llm; + public InputField playerText; + public Text AIText; - void onInputFieldSubmit(string message) - { - playerText.interactable = false; - AIText.text = "..."; - _ = llm.Chat(message, SetAIText, AIReplyComplete); - } + void Start() + { + playerText.onSubmit.AddListener(onInputFieldSubmit); + playerText.Select(); + } - public void SetAIText(string text) - { - AIText.text = text; - } + void onInputFieldSubmit(string message) + { + playerText.interactable = false; + AIText.text = "..."; + _ = llm.Chat(message, SetAIText, AIReplyComplete); + } - public void AIReplyComplete() - { - playerText.interactable = true; - playerText.Select(); - playerText.text = ""; - } + public void SetAIText(string text) + { + AIText.text = text; + } - public void CancelRequests() - { - llm.CancelRequests(); - AIReplyComplete(); - } + public void AIReplyComplete() + { + playerText.interactable = true; + playerText.Select(); + playerText.text = ""; + } - public void ExitGame() - { - Debug.Log("Exit button clicked"); - Application.Quit(); + public void CancelRequests() + { + llm.CancelRequests(); + AIReplyComplete(); + } + + public void ExitGame() + { + Debug.Log("Exit button clicked"); + Application.Quit(); + } } } diff --git a/VERSION b/VERSION index 6a5e98a7..cc904638 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v1.2.1 +v1.2.2 diff --git a/package.json b/package.json index 3b3d24dc..0ddab4ac 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "ai.undream.llmunity", - "version": "1.2.1", + "version": "1.2.2", "displayName": "LLMUnity", "description": "LLMUnity allows to run and distribute LLM models in the Unity engine.", "unity": "2022.3",