diff --git a/.gitignore b/.gitignore
index 62b6b8b1ab250..5596e43219ccd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -92,3 +92,4 @@ examples/jeopardy/results.txt
 poetry.lock
 poetry.toml
 nppBackup
+functionary-test
diff --git a/Makefile b/Makefile
index f03faf6eda0fb..740e96b7d9353 100644
--- a/Makefile
+++ b/Makefile
@@ -719,7 +719,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
-server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h examples/llava/llava.h examples/llava/llava.cpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
+server: examples/server/server.cpp examples/server/oai.hpp examples/server/utils.hpp examples/server/httplib.h examples/server/json.hpp examples/server/functionary.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h examples/llava/llava.h examples/llava/llava.cpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual
 	$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h %.hpp $< examples/llava/clip.cpp,$^) $(call GET_OBJ_FILE, $<) $(call GET_OBJ_FILE, examples/llava/clip.cpp) -o $@ $(LDFLAGS) $(LWINSOCK2)
diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt
index cc13b2d630652..1eb9ccedae4ac 100644
--- a/examples/server/CMakeLists.txt
+++ b/examples/server/CMakeLists.txt
@@ -1,7 +1,7 @@
 set(TARGET server)
 option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR})
-add_executable(${TARGET} server.cpp oai.hpp utils.hpp json.hpp httplib.h)
+add_executable(${TARGET} server.cpp oai.hpp utils.hpp json.hpp functionary.hpp httplib.h)
 install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
diff --git a/examples/server/functionary-test.cpp b/examples/server/functionary-test.cpp
new file mode 100644
index 0000000000000..15100860a4a79
--- /dev/null
+++ b/examples/server/functionary-test.cpp
@@ -0,0 +1,110 @@
+#include <iostream>
+#include <string>
+#include <sstream>
+#include <vector>
+
+#include "json.hpp"
+#include "functionary.hpp"
+
+using json = nlohmann::json;
+
+/**
+ * A simple test program that allows testing functionary.hpp without using the server.
+ * TODO: how to add this test to CI?
+ *
+ * Compile command: clear && g++ functionary-test.cpp -o functionary-test && ./functionary-test
+ */
+
+std::string test_oai_input_json = R"(
+{
+    "tools": [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_n_day_weather_forecast",
+                "description": "Get an N-day weather forecast",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA"
+                        },
+                        "format": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                            "description": "The temperature unit to use. Infer this from the user's location."
+                        },
+                        "num_days": {
+                            "type": "integer",
+                            "description": "The number of days to forecast"
+                        }
+                    },
+                    "required": ["location", "format", "num_days"]
+                }
+            }
+        }
+    ],
+    "messages": [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": "What is the weather like in Boston?"
+        },
+        {
+            "role": "assistant",
+            "content": null,
+            "tool_calls": [
+                {
+                    "type": "function",
+                    "id":"get_car_price",
+                    "function": {
+                        "arguments": "{\"car_name\": \"Song\"}",
+                        "name": "get_car_price"
+                    }
+                },
+                {
+                    "type": "function",
+                    "id":"get_car_price",
+                    "function": {
+                        "arguments": "{\"car_name\": \"Tang\"}",
+                        "name": "get_car_price"
+                    }
+                }
+            ]
+        },
+        {
+            "role": "tool",
+            "tool_call_id": "get_car_price",
+            "name": "get_car_price",
+            "content": "{\"price\": {\"price\": \"$25000\"}}"
+        },
+        {
+            "role": "tool",
+            "tool_call_id": "get_car_price",
+            "name": "get_car_price",
+            "content": "{\"price\": {\"price\": \"$20000\"}}"
+        }
+    ]
+}
+)";
+
+
+std::string test_response = R"(get_car_price
+<|content|>{"car_name": "Song"}
+<|from|>assistant
+<|recipient|>get_car_price
+<|content|>{"car_name": "Tang"}<|stop|>)";
+
+int main() {
+    auto test_oai_input = json::parse(test_oai_input_json);
+    auto prompt = llama_functionary::convert_oai_to_prompt(test_oai_input, true);
+    std::cout << "\n" << prompt << "\n";
+
+    std::cout << "\n" << llama_functionary::convert_response_to_oai_choices(test_response) << "\n";
+
+    return 0;
+}
diff --git a/examples/server/functionary.hpp b/examples/server/functionary.hpp
new file mode 100644
index 0000000000000..95e771d4069cb
--- /dev/null
+++ b/examples/server/functionary.hpp
@@ -0,0 +1,317 @@
+#include <string>
+#include <vector>
+#include <sstream>
+#include <stdexcept>
+
+#include "json.hpp"
+
+using json = nlohmann::json;
+
+/**
+ * Integration with the functionary model: https://github.com/MeetKai/functionary
+ * Based on my research: https://github.com/ggerganov/llama.cpp/issues/5588
+ *
+ * A typical flow is:
+ * - Step 1: the user sends a request to the model
+ * - Step 2: the model sends back a response to the user
+ * - Step 3: the model sends another response, this time addressed to a function (optional)
+ * - Step 4: the function sends its return value back to the model
+ * - Step 5: finally, the model sends the final response back to the user
+ */
+
+#define FUNCTIONARY_FN_PROMPT "// Supported function definitions that should be called when necessary."
+#define FUNCTIONARY_RECIP_ALL "all"
+#define FUNCTIONARY_RECIP_NONE "no-tool-call"
+
+namespace llama_functionary {
+
+template <typename T>
+static T json_value(const json &body, const std::string &key, const T &default_value)
+{
+    // Fall back to the default value if the key is missing or null
+    return body.contains(key) && !body.at(key).is_null()
+        ? body.value(key, default_value)
+        : default_value;
+}
+
+inline std::string str_replace(const std::string & original, const std::string & search, const std::string & replacement) {
+    size_t pos = original.find(search);
+    if (pos != std::string::npos) {
+        std::string result = original;
+        result.replace(pos, search.length(), replacement);
+        return result;
+    }
+    return original;
+}
+
+inline std::vector<std::string> str_split(std::string str, const std::string & delimiter) {
+    size_t pos = 0;
+    std::string token;
+    std::vector<std::string> output;
+    while ((pos = str.find(delimiter)) != std::string::npos) {
+        token = str.substr(0, pos);
+        output.push_back(token);
+        str.erase(0, pos + delimiter.length());
+    }
+    output.push_back(str); // the rest
+    return output;
+}
+
+typedef struct message {
+    std::string from; // can be "system", "user", "assistant" or the name of a function
+    std::string recipient = FUNCTIONARY_RECIP_ALL;
+    std::string content;
+    bool has_stop = false;
+    message() {}
+    message(json oai_json) {
+        from = json_value(oai_json, "role", std::string(""));
+        if (from == "tool") {
+            // response from a function
+            from = json_value(oai_json, "tool_call_id", std::string(""));
+        }
+        content = json_value(oai_json, "content", std::string(""));
+    }
+    message(std::string & prompt) {
+        std::istringstream iss(prompt);
+        std::string line;
+        std::stringstream ss;
+        int i = 0; // line number
+        while (std::getline(iss, line)) {
+            if (i == 0) {
+                from = str_replace(line, "<|from|>", "");
+            } else if (i == 1) {
+                recipient = str_replace(line, "<|recipient|>", "");
+            } else if (i == 2) {
+                ss << str_replace(line, "<|content|>", "");
+            } else {
+                ss << "\n" << line;
+            }
+            ++i;
+        }
+        has_stop = ss.str().find("<|stop|>") != std::string::npos;
+        content = str_replace(ss.str(), "<|stop|>", "");
+    }
+    std::string to_prompt() {
+        std::stringstream ss;
+        ss << "<|from|>" << from << "\n";
+        ss << "<|recipient|>" << recipient << "\n";
+        ss << "<|content|>" << content;
+        if (has_stop) {
+            ss << "<|stop|>";
+        }
+        ss << "\n";
+        return ss.str();
+    }
+} message;
+
+typedef struct function_param {
+    std::string name;
+    // type can be "string", "boolean", "number" (typescript types)
+    // we do not support array for now
+    std::string type;
+    std::string desc;
+    std::vector<json> allowed_values; // dynamic types
+    bool required;
+    function_param(std::string param_name, json & oai_json) {
+        name = param_name;
+        type = json_value(oai_json, "type", std::string());
+        desc = json_value(oai_json, "description", std::string());
+        if (oai_json.count("enum")) {
+            allowed_values = oai_json["enum"];
+        }
+    }
+} function_param;
+
+typedef struct function_def {
+    std::string name;
+    std::string desc;
+    std::vector<function_param> params;
+    // parameters.type must always be "object"
+    function_def(json & oai_json) {
+        std::string type = json_value(oai_json, "type", std::string());
+        if (type != "function") {
+            throw std::runtime_error("Only tool type \"function\" is supported");
+        }
+        // function
+        json inner_json = json_value(oai_json, "function", json::object());
+        name = json_value(inner_json, "name", std::string());
+        desc = json_value(inner_json, "description", std::string());
+        // function.parameters
+        json parameters = json_value(inner_json, "parameters", json::object());
+        std::string param_type = json_value(parameters, "type", std::string());
+        if (param_type != "object") {
+            throw std::runtime_error("Only parameters type \"object\" is supported");
+        }
+        // function.parameters.properties
+        json properties = json_value(parameters, "properties", json::object());
+        for (auto & it : properties.items()) {
+            std::string curr_prop = it.key();
+            json data = json_value(properties, curr_prop, json::object());
+            function_param param(curr_prop, data);
+            params.push_back(param);
+        }
+        // TODO: add support for "required"
+    }
+} function_def;
+
+// convert OAI type to typescript
+inline std::string oai_type_to_ts(std::string & type, std::vector<json> & allowed_values) {
+    if (!allowed_values.empty()) {
+        std::stringstream ss;
+        for (size_t i = 0; i < allowed_values.size(); ++i) {
+            ss << allowed_values[i];
+            if (i < allowed_values.size() - 1) {
+                ss << " | ";
+            }
+        }
+        return ss.str();
+    }
+    // non-enum types
+    if (type == "string" || type == "number" || type == "boolean") {
+        return type; // natively supported
+    } else if (type == "bool") {
+        return "boolean";
+    } else if (type == "integer" || type == "float" || type == "double") {
+        return "number";
+    } else {
+        throw std::runtime_error("Unsupported type: " + type);
+    }
+}
+
+inline std::string serialize_function(function_def & fn) {
+    std::stringstream ss;
+    if (fn.name.empty()) {
+        throw std::runtime_error("Function name is empty");
+    }
+    if (!fn.desc.empty()) {
+        // TODO: what if the desc has multiple lines?
+        ss << "// " << fn.desc << "\n";
+    }
+    ss << "type " << fn.name << " = (_: {\n";
+    for (auto & param : fn.params) {
+        if (!param.desc.empty()) {
+            ss << "// " << param.desc << "\n";
+        }
+        ss << param.name << ": " << oai_type_to_ts(param.type, param.allowed_values) << ",\n";
+    }
+    // only support "any" return type for now
+    ss << "}) => any;\n\n";
+    return ss.str();
+}
+
+///////////////////////////////////////////
+// Main hooks, to be called in oai.hpp
+
+inline std::string convert_oai_to_prompt(const json & body, bool add_ass, bool allow_tool = true) {
+    std::stringstream ss;
+    // convert function definitions
+    std::vector<json> tools = json_value(body, "tools", json::array());
+    if (!tools.empty()) {
+        std::stringstream ss_fn;
+        ss_fn << FUNCTIONARY_FN_PROMPT << "\n";
+        ss_fn << "namespace functions {" << "\n\n";
+        for (auto & tool : tools) {
+            function_def fn(tool);
+            ss_fn << serialize_function(fn);
+        }
+        ss_fn << "} // namespace functions";
+        // construct the message
+        message fn_def_msg;
+        fn_def_msg.from = "system";
+        fn_def_msg.recipient = FUNCTIONARY_RECIP_ALL;
+        fn_def_msg.content = ss_fn.str();
+        ss << fn_def_msg.to_prompt();
+    }
+    // convert history
+    std::vector<json> messages = json_value(body, "messages", json::array());
+    for (auto & msg_json : messages) {
+        // TODO: how to detect where to put "<|stop|>"?
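+        // An assistant message that carries "tool_calls" is expanded into one turn per call
+        // (the function name is used as the sender and its arguments as the content), and
+        // "<|stop|>" is appended only after the last call; any other message becomes a single turn.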
+ if (msg_json.count("tool_calls")) { + // assistant request to function call, now re-passed to history + std::vector tool_calls = msg_json["tool_calls"]; + for (size_t i = 0; i < tool_calls.size(); i++) { + auto & tc = tool_calls[i]; + message msg; + msg.from = tc["function"]["name"]; + msg.content = tc["function"]["arguments"]; + msg.has_stop = i == tool_calls.size() - 1; // last msg + ss << msg.to_prompt(); + } + } else { + // all other types of message + message msg(msg_json); + msg.has_stop = msg.from == "assistant"; // add stop if this is single text message from assistant (not contains tool_calls) + ss << msg.to_prompt(); + } + } + // add trailing assistant prompt + if (add_ass) { + ss << "<|from|>assistant\n<|recipient|>"; + if (!allow_tool) { + ss << FUNCTIONARY_RECIP_NONE; + } + } + return ss.str(); +} + +inline json convert_response_to_oai_choices(const std::string & content) { + std::string input_full = content; + std::string text_response; + json tool_calls = json::array(); + // parse all turns + std::vector turns = str_split(input_full, "<|from|>"); + if (!turns.empty()) { + // first turn may not have the assistant tag (because it was part of the prompt added by "add_ass"), we will put it back to parse the message + // the "<|from|>" will be added later + if (turns[0].find("<|recipient|>") == std::string::npos) { + turns[0] = "assistant\n<|recipient|>" + turns[0]; + } + } + for (auto & turn : turns) { + std::string turn_full = "<|from|>" + turn; + message msg(turn_full); + if (msg.from != "assistant") { + continue; // this case should never happen + } + if (msg.recipient != FUNCTIONARY_RECIP_ALL && msg.recipient != FUNCTIONARY_RECIP_NONE) { + // the assistant decide to call a tool (step 3) + tool_calls.push_back(json{ + {"id", msg.recipient}, // TODO: maybe generate a random part? + {"type", "function"}, + {"function", json{ + {"name", msg.recipient}, + {"arguments", msg.content}, + }}, + }); + } else { + // the assistant just want to say something (step 2) + text_response = msg.content; + } + } + // build final response + json choices = json::array(); + // TODO: technically, functionary can reponse both text + tool_call in one shot. But for some reasons, the original implementation of OpenAI only return only one, not both. 
+    if (tool_calls.size() > 0) {
+        choices.push_back(json{
+            {"index", 0},
+            {"finish_reason", "tool_calls"},
+            {"message", json{
+                {"role", "assistant"},
+                {"content", nullptr},
+                {"tool_calls", tool_calls},
+            }},
+        });
+    } else {
+        choices.push_back(json{
+            {"index", 0},
+            {"finish_reason", "stop"},
+            {"message", json{
+                {"role", "assistant"},
+                {"content", text_response},
+            }},
+        });
+    }
+    return choices;
+}
+
+} // namespace llama_functionary
diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
index ff4ad69943552..0cc06c2f114c6 100644
--- a/examples/server/oai.hpp
+++ b/examples/server/oai.hpp
@@ -9,6 +9,7 @@
 
 #include "json.hpp"
 #include "utils.hpp"
+#include "functionary.hpp"
 
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
@@ -17,7 +18,8 @@ using json = nlohmann::json;
 inline static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json &body, /* openai api json semantics */
-    const std::string &chat_template)
+    const std::string &chat_template,
+    bool enable_tool_calls)
 {
     json llama_params;
 
@@ -32,7 +34,9 @@ inline static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"] = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
+    llama_params["prompt"] = enable_tool_calls
+        ? llama_functionary::convert_oai_to_prompt(body, true)
+        : format_chat(model, chat_template, body["messages"]);
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"] = json_value(body, "temperature", 0.0);
     llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
@@ -63,13 +67,19 @@ inline static json oaicompat_completion_params_parse(
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
 
-    // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
+    llama_params["stop"].push_back(enable_tool_calls
+        ? "<|stop|>"   // functionary-specific: this model uses "<|stop|>" instead of "</s>"
+        : "<|im_end|>" // ensure there is a ChatML-specific end sequence among the stop words
+    );
 
     return llama_params;
 }
 
-inline static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
+inline static json format_final_response_oaicompat(
+    const json &request,
+    const task_result &response,
+    bool streaming,
+    bool enable_tool_calls)
 {
     json result = response.result_json;
 
@@ -84,14 +94,20 @@ inline static json format_final_response_oaicompat(const json &request, const ta
         finish_reason = "stop";
     }
 
-    json choices =
-        streaming ? json::array({json{{"finish_reason", finish_reason},
-                                      {"index", 0},
-                                      {"delta", json::object()}}})
-                  : json::array({json{{"finish_reason", finish_reason},
-                                      {"index", 0},
-                                      {"message", json{{"content", content},
-                                                       {"role", "assistant"}}}}});
+    json choices;
+
+    if (enable_tool_calls) {
+        choices = llama_functionary::convert_response_to_oai_choices(content);
+    } else {
+        choices = streaming
json::array({json{{"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()}}}) + : json::array({json{{"finish_reason", finish_reason}, + {"index", 0}, + {"message", json{{"content", content}, + {"role", "assistant"}}}}}); + } std::time_t t = std::time(0); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 524d0ada33ab0..c7d9a8ef21940 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -43,6 +43,7 @@ struct server_params int32_t read_timeout = 600; int32_t write_timeout = 600; bool slots_endpoint = true; + bool enable_tool_calls = false; }; bool server_verbose = false; @@ -2772,7 +2773,14 @@ int main(int argc, char **argv) LOG_INFO("model loaded", {}); } - if (sparams.chat_template.empty()) { // custom chat template is not supplied + // Check tool_call ability + sparams.enable_tool_calls = check_model_support_tool_calls(llama.model); + if (sparams.enable_tool_calls) { + LOG_INFO("Current model supports functionary tool_calls", {}); + } + + // custom chat template is not supplied + if (sparams.chat_template.empty() && !sparams.enable_tool_calls) { // check if the template comes with the model is supported by us llama.validate_model_chat_template(sparams); } @@ -2948,7 +2956,9 @@ int main(int argc, char **argv) if (!validate_api_key(req, res)) { return; } - json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template); + json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template, sparams.enable_tool_calls); + + // TODO: "enable_tool_calls" cannot be used with "stream" mode const int task_id = llama.queue_tasks.get_new_id(); llama.queue_results.add_waiting_task_id(task_id); @@ -2959,7 +2969,7 @@ int main(int argc, char **argv) task_result result = llama.queue_results.recv(task_id); if (!result.error && result.stop) { - json oaicompat_result = format_final_response_oaicompat(data, result); + json oaicompat_result = format_final_response_oaicompat(data, result, false, sparams.enable_tool_calls); res.set_content(oaicompat_result.dump(-1, ' ', false, json::error_handler_t::replace), diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 88545eb6931d0..2c0bb18b69950 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -211,6 +211,20 @@ inline std::string format_chat(const struct llama_model * model, const std::stri return formatted_chat; } +// Detect if the model supports tool_calls +inline bool check_model_support_tool_calls(const struct llama_model * model) { + std::vector model_template(2048, 0); + std::string template_key = "tokenizer.chat_template"; + int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); + if (res < 0) { + return false; // no template in model + } else { + model_template.resize(res); + std::string tmpl(model_template.data(), model_template.size()); + return tmpl.find("<|recipient|>") != std::string::npos; + } +} + // // work queue utils //