Release v1.2.1 #102

Merged · 19 commits · Mar 7, 2024
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,10 @@
## v1.2.1
#### 🐛 Fixes

- Kill server after Unity crash (PR: #101)
- Persist chat template on remote servers (PR: #103)


## v1.2.0
#### 🚀 Features

15 changes: 2 additions & 13 deletions CHANGELOG.release.md
@@ -1,16 +1,5 @@
### 🚀 Features

- LLM server unit tests (PR: #90)
- Implement chat templates (PR: #92)
- Stop chat functionality (PR: #95)
- Keep only the llamafile binary (PR: #97)

### 🐛 Fixes

- Fix remote server functionality (PR: #96)
- Fix Max issue needing to run llamafile manually the first time (PR: #98)

### 📦 General

- Async startup support (PR: #89)
- Kill server after Unity crash (PR: #101)
- Persist chat template on remote servers (PR: #103)

2 changes: 2 additions & 0 deletions README.md
@@ -306,8 +306,10 @@ If it is not selected, the full reply from the model is received in one go

- `Parallel Prompts` number of prompts that can happen in parallel (default: -1 = number of LLM/LLMClient objects)
- `Debug` select to log the output of the model in the Unity Editor
- `Asynchronous Startup` allows the server to be started asynchronously
- `Remote` select to allow remote access to the server
- `Port` port to run the server
- `Kill Existing Servers On Start` kills existing servers started by the Unity project on startup, to recover from Unity crashes

</details>
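For reference, a minimal sketch of setting these options from a script rather than the inspector (the component lookup is an assumption; the field names match the public fields in the `Runtime/LLM.cs` diff below):

```csharp
using UnityEngine;
using LLMUnity;

// Hypothetical setup script mirroring the inspector options listed above.
public class ServerConfig : MonoBehaviour
{
    void Awake()
    {
        LLM llm = GetComponent<LLM>();          // assumes an LLM component on this GameObject
        llm.debug = true;                       // log the model output in the Unity Editor
        llm.asynchronousStartup = true;         // start the server asynchronously
        llm.remote = false;                     // disallow remote access to the server
        llm.killExistingServersOnStart = true;  // clean up servers left over from a Unity crash
    }
}
```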

45 changes: 24 additions & 21 deletions Runtime/LLM.cs
@@ -3,7 +3,6 @@
using System.Diagnostics;
using System.IO;
using System.IO.Compression;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using UnityEditor;
@@ -23,6 +22,7 @@ public class LLM : LLMClient
[ServerAdvanced] public bool debug = false;
[ServerAdvanced] public bool asynchronousStartup = false;
[ServerAdvanced] public bool remote = false;
[ServerAdvanced] public bool killExistingServersOnStart = true;

[Model] public string model = "";
[ModelAddonAdvanced] public string lora = "";
@@ -52,6 +52,8 @@ public class LLM : LLMClient
private bool mmapCrash = false;
public bool serverListening { get; private set; } = false;
private ManualResetEvent serverBlock = new ManualResetEvent(false);
static object crashKillLock = new object();
static bool crashKill = false;

#if UNITY_EDITOR
[InitializeOnLoadMethod]
@@ -162,11 +164,19 @@ public List<LLMClient> GetListeningClients()
return clients;
}

void KillServersAfterUnityCrash()
{
    lock (crashKillLock) {
        // run the cleanup at most once per process, even if several LLM objects call it
        if (crashKill) return;
        LLMUnitySetup.KillServerAfterUnityCrash(server);
        crashKill = true;
    }
}

new public async void Awake()
{
    // start the llm server and run the Awake of the client
    if (killExistingServersOnStart) KillServersAfterUnityCrash();
    await StartLLMServer();

    base.Awake();
}

@@ -189,20 +199,6 @@ private string SelectApeBinary()
return apeExe;
}

public bool IsPortInUse()
{
    try
    {
        using (TcpClient c = new TcpClient())
        {
            c.Connect(host, port);
        }
        return true;
    }
    catch {}
    return false;
}

private void DebugLog(string message, bool logError = false)
{
// Debug log if debug is enabled
@@ -282,7 +278,8 @@ private void RunServerCommand(string exe, string args)

private async Task StartLLMServer()
{
    if (IsPortInUse()) throw new Exception($"Port {port} is already in use, please use another port or kill all llamafile processes using it!");
    bool portInUse = asynchronousStartup ? await IsServerReachableAsync() : IsServerReachable();
    if (portInUse) throw new Exception($"Port {port} is already in use, please use another port or kill all llamafile processes using it!");

    // Start the LLM server in a cross-platform way
    if (model == "") throw new Exception("No model file provided!");
@@ -322,15 +319,21 @@ private async Task StartLLMServer()
    }

    if (process.HasExited) throw new Exception("Server could not be started!");
    else LLMUnitySetup.SaveServerPID(process.Id);
}

public void StopProcess()
{
    // kill the llm server
    if (process != null && !process.HasExited)
    if (process != null)
    {
        process.Kill();
        process.WaitForExit();
        int pid = process.Id;
        if (!process.HasExited)
        {
            process.Kill();
            process.WaitForExit();
        }
        LLMUnitySetup.DeleteServerPID(pid);
    }
}
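Taken together with the `LLMUnitySetup` helpers further down, these hunks give the server PID-file bookkeeping across crashes. A condensed sketch of the intended lifecycle (illustrative, not verbatim source; the teardown call site is an assumption, not shown in this diff):

```csharp
// Condensed, illustrative view of the crash-recovery flow in this PR.
new public async void Awake()
{
    // kill any servers recorded by a previous, crashed Unity session
    if (killExistingServersOnStart) KillServersAfterUnityCrash();
    await StartLLMServer();  // spawns llamafile, then LLMUnitySetup.SaveServerPID(process.Id)
    base.Awake();
}

void OnTeardown()  // hypothetical call site for StopProcess, e.g. OnDestroy
{
    StopProcess();           // Kill() + WaitForExit(), then LLMUnitySetup.DeleteServerPID(pid)
}
```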

47 changes: 37 additions & 10 deletions Runtime/LLMClient.cs
@@ -83,7 +83,7 @@ public class LLMClient : MonoBehaviour
[TextArea(5, 10), Chat] public string prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";

protected List<ChatMessage> chat;
public string chatTemplate;
public string chatTemplate = ChatTemplate.DefaultTemplate;
public ChatTemplate template;
private List<(string, string)> requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
private string previousEndpoint;
@@ -95,7 +95,6 @@ public class LLMClient : MonoBehaviour

public async void Awake()
{
    // initialise the prompt and set the keep tokens based on its length
    InitGrammar();
    await InitPrompt();
    LoadTemplate();
@@ -131,18 +130,13 @@ private void OnValidate()
string newEndpoint = host + ":" + port;
if (newEndpoint != previousEndpoint)
{
    string template = ChatTemplate.DefaultTemplate;
    string templateToSet = chatTemplate;
    if (GetType() == typeof(LLMClient))
    {
        LLM server = GetServer();
        if (server != null) template = server.chatTemplate;
        if (server != null) templateToSet = server.chatTemplate;
    }
    else
    {
        if (chatTemplate != null && chatTemplate != "")
            template = chatTemplate;
    }
    SetTemplate(template);
    SetTemplate(templateToSet);
    previousEndpoint = newEndpoint;
}
}
@@ -375,6 +369,39 @@ public void CancelRequests()
WIPRequests.Clear();
}

public bool IsServerReachable(int timeout = 5)
{
    using (UnityWebRequest webRequest = UnityWebRequest.Head($"{host}:{port}/tokenize"))
    {
        webRequest.timeout = timeout;
        webRequest.SendWebRequest();
        // busy-wait until the request completes; this blocks the calling thread
        while (!webRequest.isDone) {}
        if (webRequest.result == UnityWebRequest.Result.ConnectionError)
        {
            return false;
        }
        return true;
    }
}

public async Task<bool> IsServerReachableAsync(int timeout = 5)
{
    using (UnityWebRequest webRequest = UnityWebRequest.Head($"{host}:{port}/tokenize"))
    {
        webRequest.timeout = timeout;
        webRequest.SendWebRequest();
        // yield to the main loop instead of blocking while the request completes
        while (!webRequest.isDone)
        {
            await Task.Yield();
        }
        if (webRequest.result == UnityWebRequest.Result.ConnectionError)
        {
            return false;
        }
        return true;
    }
}
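A minimal usage sketch for the new checks (the caller and logging are assumptions; `host` and `port` are the component's endpoint fields):

```csharp
// Hypothetical pre-flight check before sending chat requests (assumes using UnityEngine,
// System.Threading.Tasks and an LLMClient reference).
public async Task WarnIfServerDown(LLMClient client)
{
    // the async variant yields to Unity's main loop instead of busy-waiting
    bool reachable = await client.IsServerReachableAsync(timeout: 5);
    if (!reachable) Debug.LogWarning("No server responding on the configured host/port");
}
```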

public async Task<Ret> PostRequest<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
{
// send a post request to the server and call the relevant callbacks to convert the received content and handle it
114 changes: 111 additions & 3 deletions Runtime/LLMUnitySetup.cs
@@ -5,8 +5,8 @@
using Debug = UnityEngine.Debug;
using System.Threading.Tasks;
using System.Collections.Generic;
using System.IO.Compression;
using System.Net;
using System;

namespace LLMUnity
{
@@ -98,7 +98,7 @@ public void DownloadProgressChanged(object sender, DownloadProgressChangedEventArgs
public static async Task DownloadFile(
    string fileUrl, string savePath, bool overwrite = false, bool executable = false,
    TaskCallback<string> callback = null, Callback<float> progresscallback = null,
    bool async=true
    bool async = true
)
{
    // download a file to the specified path
@@ -117,7 +117,9 @@ public static async Task DownloadFile(
if (async)
{
await client.DownloadFileTaskAsync(fileUrl, tmpPath);
} else {
}
else
{
client.DownloadFile(fileUrl, tmpPath);
}
if (executable) makeExecutable(tmpPath);
@@ -164,6 +166,112 @@ await Task.Run(() =>
}
return fullPath.Substring(basePathSlash.Length + 1);
}

#endif

static string GetPIDFile()
{
    string persistDir = Path.Combine(Application.persistentDataPath, "LLMUnity");
    if (!Directory.Exists(persistDir))
    {
        Directory.CreateDirectory(persistDir);
    }
    return Path.Combine(persistDir, "server_process.txt");
}

public static void SaveServerPID(int pid)
{
    try
    {
        // append, so several server PIDs can be tracked at once
        using (StreamWriter writer = new StreamWriter(GetPIDFile(), true))
        {
            writer.WriteLine(pid);
        }
    }
    catch (Exception e)
    {
        Debug.LogError("Error saving PID to file: " + e.Message);
    }
}

static List<int> ReadServerPIDs()
{
    List<int> pids = new List<int>();
    string pidfile = GetPIDFile();
    if (!File.Exists(pidfile)) return pids;

    try
    {
        using (StreamReader reader = new StreamReader(pidfile))
        {
            string line;
            while ((line = reader.ReadLine()) != null)
            {
                if (int.TryParse(line, out int pid))
                {
                    pids.Add(pid);
                }
                else
                {
                    Debug.LogError("Invalid file entry: " + line);
                }
            }
        }
    }
    catch (Exception e)
    {
        Debug.LogError("Error reading from file: " + e.Message);
    }
    return pids;
}

public static string GetCommandLineArguments(Process process)
{
    if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer)
    {
        // on Windows only the executable path is available here, not the full command line
        return process.MainModule.FileName.Replace('\\', '/');
    }
    else
    {
        return RunProcess("ps", $"-o command -p {process.Id}").Replace("COMMAND\n", "");
    }
}

public static void KillServerAfterUnityCrash(string serverBinary)
{
    foreach (int pid in ReadServerPIDs())
    {
        try
        {
            Process process = Process.GetProcessById(pid);
            string command = GetCommandLineArguments(process);
            if (command.Contains(serverBinary))
            {
                Debug.Log($"killing existing server with {pid}: {command}");
                process.Kill();
                process.WaitForExit();
            }
        }
        catch (Exception) {}
    }

    string pidfile = GetPIDFile();
    if (File.Exists(pidfile)) File.Delete(pidfile);
}

public static void DeleteServerPID(int pid)
{
    string pidfile = GetPIDFile();
    if (!File.Exists(pidfile)) return;

    List<int> pidEntries = ReadServerPIDs();
    pidEntries.Remove(pid);

    File.Delete(pidfile);
    foreach (int pidEntry in pidEntries)
    {
        SaveServerPID(pidEntry);
    }
}
}
}
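For orientation, an illustrative round trip through the new PID-file helpers (PID value and timing hypothetical):

```csharp
// Illustrative use of the PID-file helpers above (values hypothetical).
LLMUnitySetup.SaveServerPID(12345);    // appends the PID to .../LLMUnity/server_process.txt

// ...Unity crashes; on the next launch:
LLMUnitySetup.KillServerAfterUnityCrash("llamafile");  // kills recorded processes whose
                                                       // command line contains the binary,
                                                       // then deletes the PID file

// ...or, on a clean shutdown:
LLMUnitySetup.DeleteServerPID(12345);  // rewrites the file without this entry
```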
6 changes: 6 additions & 0 deletions Samples~/ChatBot/ChatBot.cs
@@ -153,5 +153,11 @@ void Update()
lastBubbleOutsideFOV = -1;
}
}

public void ExitGame()
{
    Debug.Log("Exit button clicked");
    Application.Quit();
}
}
}
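One caveat worth noting: `Application.Quit()` has no effect in the Unity Editor's play mode. A common variant (an assumption, not part of this PR) also stops play mode when running in the Editor:

```csharp
public void ExitGame()
{
    Debug.Log("Exit button clicked");
#if UNITY_EDITOR
    UnityEditor.EditorApplication.isPlaying = false;  // Application.Quit() is ignored in the Editor
#else
    Application.Quit();
#endif
}
```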
14 changes: 0 additions & 14 deletions Samples~/ChatBot/ExitButton.cs

This file was deleted.
