diff --git a/src/open_prompt_extension.cpp b/src/open_prompt_extension.cpp index 93684e0..9791ff5 100644 --- a/src/open_prompt_extension.cpp +++ b/src/open_prompt_extension.cpp @@ -129,6 +129,10 @@ static void SetModelName(DataChunk &args, ExpressionState &state, Vector &result SetConfigValue(args, state, result, "openprompt_model_name", "Model name"); } +static void SetJsonSchema(DataChunk &args, ExpressionState &state, Vector &result) { + SetConfigValue(args, state, result, "openprompt_json_schema", "JSON Schema"); +} + // Main Function static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, Vector &result) { D_ASSERT(args.data.size() >= 1); // At least prompt required @@ -142,6 +146,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V "http://localhost:11434/v1/chat/completions"); std::string api_token = GetConfigValue(context, "openprompt_api_token", ""); std::string model_name = GetConfigValue(context, "openprompt_model_name", "qwen2.5:0.5b"); + std::string json_schema = GetConfigValue(context, "openprompt_json_schema", ""); // Override model if provided as second argument if (args.data.size() > 1 && !args.data[1].GetValue(0).IsNull()) { @@ -151,7 +156,11 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V std::string request_body = "{"; request_body += "\"model\":\"" + model_name + "\","; request_body += "\"messages\":["; - request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},"; + if (!json_schema.empty()) { + request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant. Summarize and Output JSON format (without any omissions): " + json_schema + "\"},"; + } else { + request_body += "{\"role\":\"system\",\"content\":\"You are a helpful assistant.\"},"; + } request_body += "{\"role\":\"user\",\"content\":\"" + user_prompt.GetString() + "\"}"; request_body += "]}"; @@ -167,11 +176,11 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V } auto res = client.Post(path.c_str(), headers, request_body, "application/json"); - + if (!res) { HandleHttpError(res, "POST"); } - + if (res->status != 200) { throw std::runtime_error("HTTP error " + std::to_string(res->status) + ": " + res->reason); } @@ -181,7 +190,7 @@ static void OpenPromptRequestFunction(DataChunk &args, ExpressionState &state, V duckdb_yyjson::yyjson_read(res->body.c_str(), res->body.length(), 0), &duckdb_yyjson::yyjson_doc_free ); - + if (!doc) { throw std::runtime_error("Failed to parse JSON response"); } @@ -246,6 +255,8 @@ static void LoadInternal(DatabaseInstance &instance) { "set_api_url", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetApiUrl)); ExtensionUtil::RegisterFunction(instance, ScalarFunction( "set_model_name", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetModelName)); + ExtensionUtil::RegisterFunction(instance, ScalarFunction( + "set_json_schema", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetJsonSchema)); } void OpenPromptExtension::Load(DuckDB &db) {