refactor: change ai analytics handling
* refactor ai analytics code

---------

Co-authored-by: Jack Tysoe <jack.tysoe@konghq.com>
Co-authored-by: Jack Tysoe <91137069+tysoekong@users.noreply.github.com>
Co-authored-by: Joshua Schmid <jaiks@posteo.de>
4 people authored Apr 15, 2024
1 parent 1b920ca commit 0db4b62
Showing 6 changed files with 258 additions and 57 deletions.
3 changes: 3 additions & 0 deletions changelog/unreleased/kong/update-ai-proxy-telemetry.yml
@@ -0,0 +1,3 @@
message: Update telemetry collection for AI Plugins to allow data from multiple plugin instances to be recorded for the same request.
type: bugfix
scope: Core
134 changes: 102 additions & 32 deletions kong/llm/drivers/shared.lua
@@ -10,14 +10,22 @@ local parse_url = require("socket.url").parse

local log_entry_keys = {
REQUEST_BODY = "ai.payload.request",
RESPONSE_BODY = "ai.payload.response",

TOKENS_CONTAINER = "ai.usage",
PROCESSING_TIME = "ai.usage.processing_time",

REQUEST_MODEL = "ai.meta.request_model",
RESPONSE_MODEL = "ai.meta.response_model",
PROVIDER_NAME = "ai.meta.provider_name",
RESPONSE_BODY = "payload.response",

TOKENS_CONTAINER = "usage",
META_CONTAINER = "meta",

-- meta keys
REQUEST_MODEL = "request_model",
RESPONSE_MODEL = "response_model",
PROVIDER_NAME = "provider_name",
PLUGIN_ID = "plugin_id",

-- usage keys
PROCESSING_TIME = "processing_time",
PROMPT_TOKEN = "prompt_token",
COMPLETION_TOKEN = "completion_token",
TOTAL_TOKENS = "total_tokens",
}

local openai_override = os.getenv("OPENAI_TEST_PORT")
@@ -27,6 +35,33 @@ _M.streaming_has_token_counts = {
["llama2"] = true,
}

--- Splits a table key into nested tables.
-- Each part of the key separated by dots represents a nested table.
-- @param obj The table to split keys for.
-- @return A nested table structure representing the split keys.
local function split_table_key(obj)
local result = {}

for key, value in pairs(obj) do
local keys = {}
for k in key:gmatch("[^.]+") do
table.insert(keys, k)
end

local currentTable = result
for i, k in ipairs(keys) do
if i < #keys then
currentTable[k] = currentTable[k] or {}
currentTable = currentTable[k]
else
currentTable[k] = value
end
end
end

return result
end
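
-- Illustrative usage sketch (not part of this commit): split_table_key
-- expands dot-separated keys, such as the "payload.response" log key above,
-- into nested tables. Values here are invented for the example.
local flat = {
  ["payload.response"] = '{"choices": []}',
  ["usage.total_tokens"] = 42,
}

local nested = split_table_key(flat)
-- nested now looks like:
-- { payload = { response = '{"choices": []}' }, usage = { total_tokens = 42 } }
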

_M.upstream_url_format = {
openai = fmt("%s://api.openai.com:%s", (openai_override and "http") or "https", (openai_override) or "443"),
anthropic = "https://api.anthropic.com:443",
@@ -247,61 +282,96 @@ function _M.pre_request(conf, response_object)
end

function _M.post_request(conf, response_object)
local err
local body_string, err

if type(response_object) == "string" then
-- set raw string body first, then decode
if conf.logging and conf.logging.log_payloads then
kong.log.set_serialize_value(log_entry_keys.RESPONSE_BODY, response_object)
end
body_string = response_object

-- unpack the original response object for getting token and meta info
response_object, err = cjson.decode(response_object)
if err then
return nil, "failed to decode response from JSON"
return nil, "failed to decode LLM response from JSON"
end
else
-- this has come from another AI subsystem, and contains "response" field
if conf.logging and conf.logging.log_payloads then
kong.log.set_serialize_value(log_entry_keys.RESPONSE_BODY, response_object.response or "ERROR__NOT_SET")
end
-- this has come from another AI subsystem, is already formatted, and contains "response" field
body_string = response_object.response or "ERROR__NOT_SET"
end

-- analytics and logging
if conf.logging and conf.logging.log_statistics then
local provider_name = conf.model.provider

-- check if we already have analytics in this context
local request_analytics = kong.ctx.shared.analytics

-- create a new try context
local current_try = {
[log_entry_keys.META_CONTAINER] = {},
[log_entry_keys.TOKENS_CONTAINER] = {},
}

-- create a new structure if not
if not request_analytics then
request_analytics = {
prompt_tokens = 0,
completion_tokens = 0,
total_tokens = 0,
request_analytics = {}
end

-- check if we already have analytics for this provider
local request_analytics_provider = request_analytics[provider_name]

-- create a new structure if not
if not request_analytics_provider then
request_analytics_provider = {
request_prompt_tokens = 0,
request_completion_tokens = 0,
request_total_tokens = 0,
number_of_instances = 0,
instances = {},
}
end

-- this captures the openai-format usage stats from the transformed response body
-- Set the model, response, and provider names in the current try context
current_try[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = conf.model.name
current_try[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name
current_try[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name
current_try[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id

-- Capture openai-format usage stats from the transformed response body
if response_object.usage then
if response_object.usage.prompt_tokens then
request_analytics.prompt_tokens = (request_analytics.prompt_tokens + response_object.usage.prompt_tokens)
request_analytics_provider.request_prompt_tokens = (request_analytics_provider.request_prompt_tokens + response_object.usage.prompt_tokens)
current_try[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.PROMPT_TOKEN] = response_object.usage.prompt_tokens
end
if response_object.usage.completion_tokens then
request_analytics.completion_tokens = (request_analytics.completion_tokens + response_object.usage.completion_tokens)
request_analytics_provider.request_completion_tokens = (request_analytics_provider.request_completion_tokens + response_object.usage.completion_tokens)
current_try[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.COMPLETION_TOKEN] = response_object.usage.completion_tokens
end
if response_object.usage.total_tokens then
request_analytics.total_tokens = (request_analytics.total_tokens + response_object.usage.total_tokens)
request_analytics_provider.request_total_tokens = (request_analytics_provider.request_total_tokens + response_object.usage.total_tokens)
current_try[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.TOTAL_TOKENS] = response_object.usage.total_tokens
end
end

-- update context with changed values
kong.ctx.shared.analytics = request_analytics
for k, v in pairs(request_analytics) do
kong.log.set_serialize_value(fmt("%s.%s", log_entry_keys.TOKENS_CONTAINER, k), v)
-- Log response body if logging payloads is enabled
if conf.logging and conf.logging.log_payloads then
current_try[log_entry_keys.RESPONSE_BODY] = body_string
end

kong.log.set_serialize_value(log_entry_keys.REQUEST_MODEL, conf.model.name)
kong.log.set_serialize_value(log_entry_keys.RESPONSE_MODEL, response_object.model or conf.model.name)
kong.log.set_serialize_value(log_entry_keys.PROVIDER_NAME, conf.model.provider)
-- Increment the number of instances
request_analytics_provider.number_of_instances = request_analytics_provider.number_of_instances + 1

-- Get the current try count
local try_count = request_analytics_provider.number_of_instances

-- Store the split key data in instances
request_analytics_provider.instances[try_count] = split_table_key(current_try)

-- Update context with changed values
request_analytics[provider_name] = request_analytics_provider
kong.ctx.shared.analytics = request_analytics

-- Log analytics data
kong.log.set_serialize_value(fmt("%s.%s", "ai", provider_name), request_analytics_provider)
end

return nil
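As a rough sketch of the outcome (illustrative only; values and plugin id are invented), a request that made two tries against the "openai" provider would be serialized under the "ai.openai" log key with approximately this shape:

-- Approximate shape of request_analytics_provider as logged via
-- kong.log.set_serialize_value("ai.openai", ...); values invented.
local example_ai_openai = {
  request_prompt_tokens     = 50,
  request_completion_tokens = 20,
  request_total_tokens      = 70,
  number_of_instances       = 2,
  instances = {
    {
      meta  = { request_model = "gpt-4", response_model = "gpt-4",
                provider_name = "openai", plugin_id = "<uuid>" },
      usage = { prompt_token = 25, completion_token = 10, total_tokens = 35 },
    },
    -- the second try has the same shape; a payload.response entry appears
    -- per instance only when logging.log_payloads is enabled
  },
}
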
72 changes: 54 additions & 18 deletions kong/plugins/ai-proxy/handler.lua
@@ -82,33 +82,69 @@ end


function _M:body_filter(conf)
if kong.ctx.shared.skip_response_transformer then
return
end

if (kong.response.get_status() ~= 200) and (not kong.ctx.plugin.ai_parser_error) then
-- if body_filter is called twice, then return
if kong.ctx.plugin.body_called then
return
end

-- (kong.response.get_status() == 200) or (kong.ctx.plugin.ai_parser_error)

-- all errors MUST be checked and returned in header_filter
-- we should receive a replacement response body from the same thread
if kong.ctx.shared.skip_response_transformer then
local response_body
if kong.ctx.shared.parsed_response then
response_body = kong.ctx.shared.parsed_response
elseif kong.response.get_status() == 200 then
response_body = kong.service.response.get_raw_body()
if not response_body then
kong.log.warn("issue when retrieve the response body for analytics in the body filter phase.",
" Please check AI request transformer plugin response.")
else
local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
if is_gzip then
response_body = kong_utils.inflate_gzip(response_body)
end
end
end

local original_request = kong.ctx.plugin.parsed_response
local deflated_request = original_request
local ai_driver = require("kong.llm.drivers." .. conf.model.provider)
local route_type = conf.route_type
local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type)

if err then
kong.log.warn("issue when transforming the response body for analytics in the body filter phase, ", err)
elseif new_response_string then
ai_shared.post_request(conf, new_response_string)
end
end

if deflated_request then
local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
if is_gzip then
deflated_request = kong_utils.deflate_gzip(deflated_request)
if not kong.ctx.shared.skip_response_transformer then
if (kong.response.get_status() ~= 200) and (not kong.ctx.plugin.ai_parser_error) then
return
end

-- (kong.response.get_status() == 200) or (kong.ctx.plugin.ai_parser_error)

-- all errors MUST be checked and returned in header_filter
-- we should receive a replacement response body from the same thread

local original_request = kong.ctx.plugin.parsed_response
local deflated_request = original_request

if deflated_request then
local is_gzip = kong.response.get_header("Content-Encoding") == "gzip"
if is_gzip then
deflated_request = kong_utils.deflate_gzip(deflated_request)
end

kong.response.set_raw_body(deflated_request)
end

kong.response.set_raw_body(deflated_request)
-- call with replacement body, or original body if nothing changed
local _, err = ai_shared.post_request(conf, original_request)
if err then
kong.log.warn("analytics phase failed for request, ", err)
end
end

-- call with replacement body, or original body if nothing changed
ai_shared.post_request(conf, original_request)
kong.ctx.plugin.body_called = true
end


1 change: 1 addition & 0 deletions kong/plugins/ai-request-transformer/handler.lua
@@ -44,6 +44,7 @@ function _M:access(conf)

-- first find the configured LLM interface and driver
local http_opts = create_http_opts(conf)
conf.llm.__plugin_id = conf.__plugin_id
local ai_driver, err = llm:new(conf.llm, http_opts)

if not ai_driver then
3 changes: 3 additions & 0 deletions kong/plugins/ai-response-transformer/handler.lua
@@ -97,6 +97,7 @@ function _M:access(conf)

-- first find the configured LLM interface and driver
local http_opts = create_http_opts(conf)
conf.llm.__plugin_id = conf.__plugin_id
local ai_driver, err = llm:new(conf.llm, http_opts)

if not ai_driver then
@@ -116,6 +117,8 @@ function _M:access(conf)
res_body = kong_utils.inflate_gzip(res_body)
end

kong.ctx.shared.parsed_response = res_body

-- if asked, introspect the request before proxying
kong.log.debug("introspecting response with LLM")


1 comment on commit 0db4b62

@github-actions
Contributor


Bazel Build

Docker image available kong/kong:0db4b62421990fb5b0d6cb341fd357ebaaaf50d2
Artifacts available https://github.com/Kong/kong/actions/runs/8691590529
