diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp index e06aeb0..93684e0 100644 --- a/src/open_prompt_extension.cpp +++ b/src/open_prompt_extension.cpp @@ -14,20 +14,10 @@ #include #include #include -#include - +#include "yyjson.hpp" namespace duckdb { - struct OpenPromptData: FunctionData { - unique_ptr Copy() const { - throw std::runtime_error("OpenPromptData::Copy"); - }; - bool Equals(const FunctionData &other) const { - throw std::runtime_error("OpenPromptData::Equals"); - }; - }; - -// Helper function to parse URL and setup client + static std::pair SetupHttpClient(const std::string &url) { std::string scheme, domain, path; size_t pos = url.find("://"); @@ -46,7 +36,6 @@ static std::pair SetupHttpClient(co path = "/"; } - // Create client and set a reasonable timeout (e.g., 10 seconds) duckdb_httplib_openssl::Client client(domain.c_str()); client.set_read_timeout(10, 0); // 10 seconds client.set_follow_location(true); // Follow redirects @@ -98,184 +87,167 @@ static void HandleHttpError(const duckdb_httplib_openssl::Result &res, const std throw std::runtime_error(err_message); } - -// Open Prompt -// Global settings - static std::string api_url = "http://localhost:11434/v1/chat/completions"; - static std::string api_token; // Store your API token here - static std::string model_name = "qwen2.5:0.5b"; // Default model - static std::mutex settings_mutex; - - // Function to set API token - void SetApiToken(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::Execute(args.data[0], result, args.size(), - [&](string_t token) { - try { - auto _token = token.GetData(); - if (token.Empty()) { - throw std::invalid_argument("API token cannot be empty."); - } - ClientConfig::GetConfig(state.GetContext()).SetUserVariable( - "openprompt_api_token", - Value::CreateValue(token.GetString())); - return StringVector::AddString(result, string("token : ") + string(_token, token.GetSize())); - } catch (std::exception &e) { - string_t res(e.what()); - res.Finalize(); - return res; - } - }); +// Settings management +static std::string GetConfigValue(ClientContext &context, const string &var_name, const string &default_value) { + Value value; + auto &config = ClientConfig::GetConfig(context); + if (!config.GetUserVariable(var_name, value) || value.IsNull()) { + return default_value; } + return value.ToString(); +} - // Function to set API URL - void SetApiUrl(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::Execute(args.data[0], result, args.size(), - [&](string_t token) { +static void SetConfigValue(DataChunk &args, ExpressionState &state, Vector &result, + const string &var_name, const string &value_type) { + UnaryExecutor::Execute(args.data[0], result, args.size(), + [&](string_t value) { try { - auto _token = token.GetData(); - if (token.Empty()) { - throw std::invalid_argument("API token cannot be empty."); + if (value == "" || value.GetSize() == 0) { + throw std::invalid_argument(value_type + " cannot be empty."); } + ClientConfig::GetConfig(state.GetContext()).SetUserVariable( - "openprompt_api_url", - Value::CreateValue(token.GetString())); - return StringVector::AddString(result, string("url : ") + string(_token, token.GetSize())); + var_name, + Value::CreateValue(value.GetString()) + ); + return StringVector::AddString(result, value_type + " set to: " + value.GetString()); } catch (std::exception &e) { - string_t res(e.what()); - res.Finalize(); - return res; + return StringVector::AddString(result, "Failed to set " + value_type + ": " + e.what()); } }); - } - - // Function to set model name - void SetModelName(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::Execute(args.data[0], result, args.size(), - [&](string_t token) { - try { - auto _token = token.GetData(); - if (token.Empty()) { - throw std::invalid_argument("API token cannot be empty."); - } - ClientConfig::GetConfig(state.GetContext()).SetUserVariable( - "openprompt_model_name", - Value::CreateValue(token.GetString())); - return StringVector::AddString(result, string("name : ") + string(_token, token.GetSize())); - } catch (std::exception &e) { - string_t res(e.what()); - res.Finalize(); - return res; - } - }); - } +} - // Retrieve the API URL from the stored settings - static std::string GetApiUrl() { - std::lock_guard guard(settings_mutex); - return api_url.empty() ? "http://localhost:11434/v1/chat/completions" : api_url; - } +static void SetApiToken(DataChunk &args, ExpressionState &state, Vector &result) { + SetConfigValue(args, state, result, "openprompt_api_token", "API token"); +} - // Retrieve the API token from the stored settings - static std::string GetApiToken() { - std::lock_guard guard(settings_mutex); - return api_token; - } +static void SetApiUrl(DataChunk &args, ExpressionState &state, Vector &result) { + SetConfigValue(args, state, result, "openprompt_api_url", "API URL"); +} - // Retrieve the model name from the stored settings - static std::string GetModelName() { - std::lock_guard guard(settings_mutex); - return model_name.empty() ? "qwen2.5:0.5b" : model_name; - } +static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result) { + SetConfigValue(args, state, result, "openprompt_model_name", "Model name"); +} - template a assert_null(a val) { - if (val == nullptr) { - throw std::runtime_error("Failed to parse the first message content in the API response."); - } - return val; - } -// Open Prompt Function +// Main Function static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.data.size() >= 1); // At least prompt required + UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t user_prompt) { - auto &conf = ClientConfig::GetConfig(state.GetContext()); - Value api_url; - Value api_token; - Value model_name; - conf.GetUserVariable("openprompt_api_url", api_url); - conf.GetUserVariable("openprompt_api_token", api_token); - conf.GetUserVariable("openprompt_model_name", model_name); - - // Manually construct the JSON body as a string. TODO use json parser from extension. + auto &context = state.GetContext(); + + // Get configuration with defaults + std::string api_url = GetConfigValue(context, "openprompt_api_url", + "http://localhost:11434/v1/chat/completions"); + std::string api_token = GetConfigValue(context, "openprompt_api_token", ""); + std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b"); + + // Override model if provided as second argument + if (args.data.size() > 1 && !args.data[1].GetValue(0).IsNull()) { + model_name = args.data[1].GetValue(0).ToString(); + } + std::string request_body = "{"; - request_body += "\"model\":\"" + model_name.ToString() + "\","; + request_body += "\"model\":\"" + model_name + "\","; request_body += "\"messages\":["; request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},"; request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}"; request_body += "]}"; try { - // Make the POST request - auto client_and_path = SetupHttpClient(api_url.ToString()); + auto client_and_path = SetupHttpClient(api_url); auto &client = client_and_path.first; auto &path = client_and_path.second; - // Setup headers - duckdb_httplib_openssl::Headers header_map; - header_map.emplace("Content-Type", "application/json"); - if (!api_token.ToString().empty()) { - header_map.emplace("Authorization", "Bearer " + api_token.ToString()); + duckdb_httplib_openssl::Headers headers; + headers.emplace("Content-Type", "application/json"); + if (!api_token.empty()) { + headers.emplace("Authorization", "Bearer " + api_token); + } + + auto res = client.Post(path.c_str(), headers, request_body, "application/json"); + + if (!res) { + HandleHttpError(res, "POST"); + } + + if (res->status != 200) { + throw std::runtime_error("HTTP error " + std::to_string(res->status) + ": " + res->reason); } - // Send the request - auto res = client.Post(path.c_str(), header_map, request_body, "application/json"); - if (res && res->status == 200) { - // Extract the first choice's message content from the response - std::string response_body = res->body; - unique_ptr doc( - nullptr, &duckdb_yyjson::yyjson_doc_free - ); - doc.reset(assert_null( - duckdb_yyjson::yyjson_read(response_body.c_str(), response_body.length(), 0) - )); - auto root = assert_null(duckdb_yyjson::yyjson_doc_get_root(doc.get())); - auto choices = assert_null(duckdb_yyjson::yyjson_obj_get(root, "choices")); - auto choices_0 = assert_null(duckdb_yyjson::yyjson_arr_get_first(choices)); - auto message = assert_null(duckdb_yyjson::yyjson_obj_get(choices_0, "message")); - auto content = assert_null(duckdb_yyjson::yyjson_obj_get(message, "content")); - auto c_content = assert_null(duckdb_yyjson::yyjson_get_str(content)); - return StringVector::AddString(result, c_content); + try { + unique_ptr doc( + duckdb_yyjson::yyjson_read(res->body.c_str(), res->body.length(), 0), + &duckdb_yyjson::yyjson_doc_free + ); + + if (!doc) { + throw std::runtime_error("Failed to parse JSON response"); + } + + auto root = duckdb_yyjson::yyjson_doc_get_root(doc.get()); + if (!root) { + throw std::runtime_error("Invalid JSON response: no root object"); + } + + auto choices = duckdb_yyjson::yyjson_obj_get(root, "choices"); + if (!choices || !duckdb_yyjson::yyjson_is_arr(choices)) { + throw std::runtime_error("Invalid response format: missing choices array"); + } + + auto first_choice = duckdb_yyjson::yyjson_arr_get_first(choices); + if (!first_choice) { + throw std::runtime_error("Empty choices array in response"); + } + + auto message = duckdb_yyjson::yyjson_obj_get(first_choice, "message"); + if (!message) { + throw std::runtime_error("Missing message in response"); + } + + auto content = duckdb_yyjson::yyjson_obj_get(message, "content"); + if (!content) { + throw std::runtime_error("Missing content in response"); + } + + auto content_str = duckdb_yyjson::yyjson_get_str(content); + if (!content_str) { + throw std::runtime_error("Invalid content in response"); + } + + return StringVector::AddString(result, content_str); + } catch (std::exception &e) { + throw std::runtime_error("Failed to parse response: " + std::string(e.what())); } - throw std::runtime_error("HTTP POST error: " + std::to_string(res->status) + " - " + res->reason); } catch (std::exception &e) { - // In case of any error, return the original input text to avoid disruption - return StringVector::AddString(result, e.what()); + // Log error and return error message + return StringVector::AddString(result, "Error: " + std::string(e.what())); } }); } - +// LoadInternal function static void LoadInternal(DatabaseInstance &instance) { - // Register open_prompt function with two arguments: prompt and model ScalarFunctionSet open_prompt("open_prompt"); + + // Register with both single and two-argument variants open_prompt.AddFunction(ScalarFunction( {LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction)); + open_prompt.AddFunction(ScalarFunction( + {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, OpenPromptRequestFunction)); + ExtensionUtil::RegisterFunction(instance, open_prompt); - // Other set_* functions remain the same as before + // Register setting functions ExtensionUtil::RegisterFunction(instance, ScalarFunction( - "set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR, - SetApiToken)); - + "set_api_token", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiToken)); ExtensionUtil::RegisterFunction(instance, ScalarFunction( - "set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, - SetApiUrl)); - + "set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl)); ExtensionUtil::RegisterFunction(instance, ScalarFunction( - "set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName - )); + "set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName)); } - void OpenPromptExtension::Load(DuckDB &db) { LoadInternal(*db.instance); } @@ -292,7 +264,6 @@ std::string OpenPromptExtension::Version() const { #endif } - } // namespace duckdb extern "C" { @@ -309,4 +280,3 @@ DUCKDB_EXTENSION_API const char *open_prompt_version() { #ifndef DUCKDB_EXTENSION_MAIN #error DUCKDB_EXTENSION_MAIN not defined #endif -