diff --git a/changelog/unreleased/kong/update-ai-proxy-telemetry.yml b/changelog/unreleased/kong/update-ai-proxy-telemetry.yml
new file mode 100644
index 000000000000..e4ac98afa760
--- /dev/null
+++ b/changelog/unreleased/kong/update-ai-proxy-telemetry.yml
@@ -0,0 +1,3 @@
+message: Update telemetry collection for AI Plugins to allow multiple instances data to be set for the same request.
+type: bugfix
+scope: Core
diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua
index b38fd8d6c848..69c89d9d5c51 100644
--- a/kong/llm/drivers/shared.lua
+++ b/kong/llm/drivers/shared.lua
@@ -10,14 +10,22 @@ local parse_url = require("socket.url").parse
 
 local log_entry_keys = {
   REQUEST_BODY = "ai.payload.request",
-  RESPONSE_BODY = "ai.payload.response",
-
-  TOKENS_CONTAINER = "ai.usage",
-  PROCESSING_TIME = "ai.usage.processing_time",
-
-  REQUEST_MODEL = "ai.meta.request_model",
-  RESPONSE_MODEL = "ai.meta.response_model",
-  PROVIDER_NAME = "ai.meta.provider_name",
+  RESPONSE_BODY = "payload.response",
+
+  TOKENS_CONTAINER = "usage",
+  META_CONTAINER = "meta",
+
+  -- meta keys
+  REQUEST_MODEL = "request_model",
+  RESPONSE_MODEL = "response_model",
+  PROVIDER_NAME = "provider_name",
+  PLUGIN_ID = "plugin_id",
+
+  -- usage keys
+  PROCESSING_TIME = "processing_time",
+  PROMPT_TOKEN = "prompt_token",
+  COMPLETION_TOKEN = "completion_token",
+  TOTAL_TOKENS = "total_tokens",
 }
 
 local openai_override = os.getenv("OPENAI_TEST_PORT")
@@ -27,6 +35,33 @@ _M.streaming_has_token_counts = {
   ["llama2"] = true,
 }
 
+--- Splits a table key into nested tables.
+-- Each part of the key separated by dots represents a nested table.
+-- @param obj The table to split keys for.
+-- @return A nested table structure representing the split keys.
+local function split_table_key(obj)
+  local result = {}
+
+  for key, value in pairs(obj) do
+    local keys = {}
+    for k in key:gmatch("[^.]+") do
+      table.insert(keys, k)
+    end
+
+    local currentTable = result
+    for i, k in ipairs(keys) do
+      if i < #keys then
+        currentTable[k] = currentTable[k] or {}
+        currentTable = currentTable[k]
+      else
+        currentTable[k] = value
+      end
+    end
+  end
+
+  return result
+end
+
 _M.upstream_url_format = {
   openai = fmt("%s://api.openai.com:%s", (openai_override and "http") or "https", (openai_override) or "443"),
   anthropic = "https://api.anthropic.com:443",
@@ -247,61 +282,96 @@ function _M.pre_request(conf, request_table)
 end
 
 function _M.post_request(conf, response_object)
-  local err
+  local body_string, err
 
   if type(response_object) == "string" then
     -- set raw string body first, then decode
-    if conf.logging and conf.logging.log_payloads then
-      kong.log.set_serialize_value(log_entry_keys.RESPONSE_BODY, response_object)
-    end
+    body_string = response_object
 
+    -- unpack the original response object for getting token and meta info
     response_object, err = cjson.decode(response_object)
     if err then
-      return nil, "failed to decode response from JSON"
+      return nil, "failed to decode LLM response from JSON"
     end
   else
-    -- this has come from another AI subsystem, and contains "response" field
-    if conf.logging and conf.logging.log_payloads then
-      kong.log.set_serialize_value(log_entry_keys.RESPONSE_BODY, response_object.response or "ERROR__NOT_SET")
-    end
+    -- this has come from another AI subsystem, is already formatted, and contains "response" field
+    body_string = response_object.response or "ERROR__NOT_SET"
   end
 
   -- analytics and logging
   if conf.logging and conf.logging.log_statistics then
+    local provider_name = conf.model.provider
+
     -- check if we already have analytics in this context
     local request_analytics = kong.ctx.shared.analytics
 
+    -- create a new try context
+    local current_try = {
+      [log_entry_keys.META_CONTAINER] = {},
+      [log_entry_keys.TOKENS_CONTAINER] = {},
+    }
+
     -- create a new structure if not
     if not request_analytics then
-      request_analytics = {
-        prompt_tokens = 0,
-        completion_tokens = 0,
-        total_tokens = 0,
+      request_analytics = {}
+    end
+
+    -- check if we already have analytics for this provider
+    local request_analytics_provider = request_analytics[provider_name]
+
+    -- create a new structure if not
+    if not request_analytics_provider then
+      request_analytics_provider = {
+        request_prompt_tokens = 0,
+        request_completion_tokens = 0,
+        request_total_tokens = 0,
+        number_of_instances = 0,
+        instances = {},
       }
     end
 
-    -- this captures the openai-format usage stats from the transformed response body
+    -- Set the model, response, and provider names in the current try context
+    current_try[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = conf.model.name
+    current_try[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name
+    current_try[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name
+    current_try[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id
+
+    -- Capture openai-format usage stats from the transformed response body
     if response_object.usage then
       if response_object.usage.prompt_tokens then
-        request_analytics.prompt_tokens = (request_analytics.prompt_tokens + response_object.usage.prompt_tokens)
+        request_analytics_provider.request_prompt_tokens = (request_analytics_provider.request_prompt_tokens + response_object.usage.prompt_tokens)
+        current_try[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.PROMPT_TOKEN] = response_object.usage.prompt_tokens
       end
 
       if response_object.usage.completion_tokens then
-        request_analytics.completion_tokens = (request_analytics.completion_tokens + response_object.usage.completion_tokens)
+        request_analytics_provider.request_completion_tokens = (request_analytics_provider.request_completion_tokens + response_object.usage.completion_tokens)
+        current_try[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.COMPLETION_TOKEN] = response_object.usage.completion_tokens
      end
 
       if response_object.usage.total_tokens then
-        request_analytics.total_tokens = (request_analytics.total_tokens + response_object.usage.total_tokens)
+        request_analytics_provider.request_total_tokens = (request_analytics_provider.request_total_tokens + response_object.usage.total_tokens)
+        current_try[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.TOTAL_TOKENS] = response_object.usage.total_tokens
      end
    end
 
-    -- update context with changed values
-    kong.ctx.shared.analytics = request_analytics
-    for k, v in pairs(request_analytics) do
-      kong.log.set_serialize_value(fmt("%s.%s", log_entry_keys.TOKENS_CONTAINER, k), v)
+    -- Log response body if logging payloads is enabled
+    if conf.logging and conf.logging.log_payloads then
+      current_try[log_entry_keys.RESPONSE_BODY] = body_string
     end
 
-    kong.log.set_serialize_value(log_entry_keys.REQUEST_MODEL, conf.model.name)
-    kong.log.set_serialize_value(log_entry_keys.RESPONSE_MODEL, response_object.model or conf.model.name)
-    kong.log.set_serialize_value(log_entry_keys.PROVIDER_NAME, conf.model.provider)
+    -- Increment the number of instances
+    request_analytics_provider.number_of_instances = request_analytics_provider.number_of_instances + 1
+
+    -- Get the current try count
+    local try_count = request_analytics_provider.number_of_instances
+
+    -- Store the split key data in instances
+    request_analytics_provider.instances[try_count] = split_table_key(current_try)
+
+    -- Update context with changed values
+    request_analytics[provider_name] = request_analytics_provider
+    kong.ctx.shared.analytics = request_analytics
+
+    -- Log analytics data
+    kong.log.set_serialize_value(fmt("%s.%s", "ai", provider_name), request_analytics_provider)
   end
 
   return nil
diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua
index 64ea7c17dd22..b5886683fcc9 100644
--- a/kong/plugins/ai-proxy/handler.lua
+++ b/kong/plugins/ai-proxy/handler.lua
@@ -82,33 +82,69 @@ end
 
 
 function _M:body_filter(conf)
-  if kong.ctx.shared.skip_response_transformer then
-    return
-  end
-
-  if (kong.response.get_status() ~= 200) and (not kong.ctx.plugin.ai_parser_error) then
+  -- if body_filter is called twice, then return
+  if kong.ctx.plugin.body_called then
     return
   end
 
-  -- (kong.response.get_status() == 200) or (kong.ctx.plugin.ai_parser_error)
-
-  -- all errors MUST be checked and returned in header_filter
-  -- we should receive a replacement response body from the same thread
+  if kong.ctx.shared.skip_response_transformer then
+    local response_body
+    if kong.ctx.shared.parsed_response then
+      response_body = kong.ctx.shared.parsed_response
+    elseif kong.response.get_status() == 200 then
+      response_body = kong.service.response.get_raw_body()
+      if not response_body then
+        kong.log.warn("issue when retrieve the response body for analytics in the body filter phase.",
+                      " Please check AI request transformer plugin response.")
+      else
+        local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
+        if is_gzip then
+          response_body = kong_utils.inflate_gzip(response_body)
+        end
+      end
+    end
 
-  local original_request = kong.ctx.plugin.parsed_response
-  local deflated_request = original_request
+    local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
+    local route_type = conf.route_type
+    local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type)
+
+    if err then
+      kong.log.warn("issue when transforming the response body for analytics in the body filter phase, ", err)
+    elseif new_response_string then
+      ai_shared.post_request(conf, new_response_string)
+    end
+  end
 
-  if deflated_request then
-    local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
-    if is_gzip then
-      deflated_request = kong_utils.deflate_gzip(deflated_request)
+  if not kong.ctx.shared.skip_response_transformer then
+    if (kong.response.get_status() ~= 200) and (not kong.ctx.plugin.ai_parser_error) then
+      return
+    end
+
+    -- (kong.response.get_status() == 200) or (kong.ctx.plugin.ai_parser_error)
+
+    -- all errors MUST be checked and returned in header_filter
+    -- we should receive a replacement response body from the same thread
+
+    local original_request = kong.ctx.plugin.parsed_response
+    local deflated_request = original_request
+
+    if deflated_request then
+      local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
+      if is_gzip then
+        deflated_request = kong_utils.deflate_gzip(deflated_request)
+      end
+
+      kong.response.set_raw_body(deflated_request)
     end
 
-    kong.response.set_raw_body(deflated_request)
+    -- call with replacement body, or original body if nothing changed
+    local _, err = ai_shared.post_request(conf, original_request)
+    if err then
+      kong.log.warn("analytics phase failed for request, ", err)
+    end
   end
 
-  -- call with replacement body, or original body if nothing changed
-  ai_shared.post_request(conf, original_request)
+  kong.ctx.plugin.body_called = true
 end
diff --git a/kong/plugins/ai-request-transformer/handler.lua b/kong/plugins/ai-request-transformer/handler.lua
index 7efd0e0c72ef..7553877660a5 100644
--- a/kong/plugins/ai-request-transformer/handler.lua
+++ b/kong/plugins/ai-request-transformer/handler.lua
@@ -44,6 +44,7 @@ function _M:access(conf)
 
   -- first find the configured LLM interface and driver
   local http_opts = create_http_opts(conf)
+  conf.llm.__plugin_id = conf.__plugin_id
   local ai_driver, err = llm:new(conf.llm, http_opts)
 
   if not ai_driver then
diff --git a/kong/plugins/ai-response-transformer/handler.lua b/kong/plugins/ai-response-transformer/handler.lua
index d4535b37e6d5..7fd4a2900b79 100644
--- a/kong/plugins/ai-response-transformer/handler.lua
+++ b/kong/plugins/ai-response-transformer/handler.lua
@@ -97,6 +97,7 @@ function _M:access(conf)
 
   -- first find the configured LLM interface and driver
   local http_opts = create_http_opts(conf)
+  conf.llm.__plugin_id = conf.__plugin_id
   local ai_driver, err = llm:new(conf.llm, http_opts)
 
   if not ai_driver then
@@ -116,6 +117,8 @@ function _M:access(conf)
       res_body = kong_utils.inflate_gzip(res_body)
     end
 
+    kong.ctx.shared.parsed_response = res_body
+
     -- if asked, introspect the request before proxying
     kong.log.debug("introspecting response with LLM")
 
diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
index 8919fbe06524..b9d8c31888d4 100644
--- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
+++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua
@@ -1,10 +1,63 @@
 local helpers = require "spec.helpers"
 local cjson = require "cjson"
 local pl_file = require "pl.file"
+local pl_stringx = require "pl.stringx"
 
 local PLUGIN_NAME = "ai-proxy"
 local MOCK_PORT = helpers.get_available_port()
 
+
+local FILE_LOG_PATH_STATS_ONLY = os.tmpname()
+local FILE_LOG_PATH_NO_LOGS = os.tmpname()
+local FILE_LOG_PATH_WITH_PAYLOADS = os.tmpname()
+
+
+local function wait_for_json_log_entry(FILE_LOG_PATH)
+  local json
+
+  assert
+    .with_timeout(10)
+    .ignore_exceptions(true)
+    .eventually(function()
+      local data = assert(pl_file.read(FILE_LOG_PATH))
+
+      data = pl_stringx.strip(data)
+      assert(#data > 0, "log file is empty")
+
+      data = data:match("%b{}")
+      assert(data, "log file does not contain JSON")
+
+      json = cjson.decode(data)
+    end)
+    .has_no_error("log file contains a valid JSON entry")
+
+  return json
+end
+
+local _EXPECTED_CHAT_STATS = {
+  openai = {
+    instances = {
+      {
+        meta = {
+          plugin_id = '6e7c40f6-ce96-48e4-a366-d109c169e444',
+          provider_name = 'openai',
+          request_model = 'gpt-3.5-turbo',
+          response_model = 'gpt-3.5-turbo-0613',
+        },
+        usage = {
+          completion_token = 12,
+          prompt_token = 25,
+          total_tokens = 37,
+        },
+      },
+    },
+    number_of_instances = 1,
+    request_completion_tokens = 12,
+    request_prompt_tokens = 25,
+    request_total_tokens = 37,
+  },
+}
+
 for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
   describe(PLUGIN_NAME .. ": (access) [#" .. strategy .. "]", function()
     local client
@@ -157,9 +210,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
       })
       bp.plugins:insert {
         name = PLUGIN_NAME,
+        id = "6e7c40f6-ce96-48e4-a366-d109c169e444",
         route = { id = chat_good.id },
         config = {
           route_type = "llm/v1/chat",
+          logging = {
+            log_payloads = false,
+            log_statistics = true,
+          },
           auth = {
             header_name = "Authorization",
             header_value = "Bearer openai-key",
@@ -179,7 +237,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         name = "file-log",
         route = { id = chat_good.id },
         config = {
-          path = "/dev/stdout",
+          path = FILE_LOG_PATH_STATS_ONLY,
         },
       }
       --
@@ -219,7 +277,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         name = "file-log",
         route = { id = chat_good_no_stats.id },
         config = {
-          path = "/dev/stdout",
+          path = FILE_LOG_PATH_NO_LOGS,
        },
      }
      --
@@ -259,7 +317,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
         name = "file-log",
         route = { id = chat_good_log_payloads.id },
         config = {
-          path = "/dev/stdout",
+          path = FILE_LOG_PATH_WITH_PAYLOADS,
         },
       }
      --
@@ -517,10 +575,16 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
 
     before_each(function()
       client = helpers.proxy_client()
+      os.remove(FILE_LOG_PATH_STATS_ONLY)
+      os.remove(FILE_LOG_PATH_NO_LOGS)
+      os.remove(FILE_LOG_PATH_WITH_PAYLOADS)
     end)
 
     after_each(function()
       if client then client:close() end
+      os.remove(FILE_LOG_PATH_STATS_ONLY)
+      os.remove(FILE_LOG_PATH_NO_LOGS)
+      os.remove(FILE_LOG_PATH_WITH_PAYLOADS)
     end)
 
     describe("openai general", function()
@@ -549,7 +613,13 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
           role = "assistant",
         }, json.choices[1].message)
 
-        -- TODO TEST THE LOG FILE
+        local log_message = wait_for_json_log_entry(FILE_LOG_PATH_STATS_ONLY)
+        assert.same("127.0.0.1", log_message.client_ip)
+        assert.is_number(log_message.request.size)
+        assert.is_number(log_message.response.size)
+
+        -- test ai-proxy stats
+        assert.same(_EXPECTED_CHAT_STATS, log_message.ai)
       end)
 
       it("does not log statistics", function()
@@ -576,8 +646,14 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
           content = "The sum of 1 + 1 is 2.",
           role = "assistant",
         }, json.choices[1].message)
-
-        -- TODO TEST THE LOG FILE
+
+        local log_message = wait_for_json_log_entry(FILE_LOG_PATH_NO_LOGS)
+        assert.same("127.0.0.1", log_message.client_ip)
+        assert.is_number(log_message.request.size)
+        assert.is_number(log_message.response.size)
+
+        -- test ai-proxy has no stats
+        assert.same(nil, log_message.ai)
       end)
 
       it("logs payloads", function()
@@ -605,7 +681,19 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
           role = "assistant",
         }, json.choices[1].message)
 
-        -- TODO TEST THE LOG FILE
+        local log_message = wait_for_json_log_entry(FILE_LOG_PATH_WITH_PAYLOADS)
+        assert.same("127.0.0.1", log_message.client_ip)
+        assert.is_number(log_message.request.size)
+        assert.is_number(log_message.response.size)
+
+        -- test request bodies
+        assert.matches('"content": "What is 1 + 1?"', log_message.ai.payload.request, nil, true)
+        assert.matches('"role": "user"', log_message.ai.payload.request, nil, true)
+
+        -- test response bodies
+        assert.matches('"content": "The sum of 1 + 1 is 2.",', log_message.ai.openai.instances[1].payload.response, nil, true)
+        assert.matches('"role": "assistant"', log_message.ai.openai.instances[1].payload.response, nil, true)
+        assert.matches('"id": "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2"', log_message.ai.openai.instances[1].payload.response, nil, true)
       end)
 
       it("internal_server_error request", function()
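
Note (reviewer illustration, not part of the patch): the new `split_table_key` helper in kong/llm/drivers/shared.lua is what expands the dotted `log_entry_keys` (only "payload.response" still contains a dot) into nested tables before each try is stored under `instances`. The standalone Lua sketch below copies that helper and feeds it a hypothetical `current_try` table whose numbers mirror the `_EXPECTED_CHAT_STATS` fixture from the spec; the sample values are illustrative only.

-- Minimal, self-contained sketch; the helper body matches the hunk above,
-- while the sample input below is hypothetical.
local function split_table_key(obj)
  local result = {}

  for key, value in pairs(obj) do
    local keys = {}
    for k in key:gmatch("[^.]+") do
      table.insert(keys, k)
    end

    local currentTable = result
    for i, k in ipairs(keys) do
      if i < #keys then
        currentTable[k] = currentTable[k] or {}
        currentTable = currentTable[k]
      else
        currentTable[k] = value
      end
    end
  end

  return result
end

-- Shaped like the per-try table built in _M.post_request(): "meta" and "usage"
-- are plain keys, while "payload.response" is dotted and only present when
-- conf.logging.log_payloads is enabled.
local current_try = {
  meta  = { provider_name = "openai", request_model = "gpt-3.5-turbo" },
  usage = { prompt_token = 25, completion_token = 12, total_tokens = 37 },
  ["payload.response"] = '{"choices": []}',  -- illustrative body string
}

local instance = split_table_key(current_try)
print(instance.meta.provider_name)  --> openai
print(instance.usage.total_tokens)  --> 37
print(instance.payload.response)    --> {"choices": []}

This nesting is why the payload assertions in the spec read the response body from log_message.ai.openai.instances[1].payload.response rather than from a flat "ai.payload.response" key.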