fix(ai-proxy): abstract a base for ai-proxy #11991

Merged · 5 commits · Feb 25, 2025
13 changes: 5 additions & 8 deletions apisix/plugins/ai-proxy-multi.lua
@@ -17,13 +17,12 @@

 local core = require("apisix.core")
 local schema = require("apisix.plugins.ai-proxy.schema")
-local ai_proxy = require("apisix.plugins.ai-proxy")
 local plugin = require("apisix.plugin")
+local base = require("apisix.plugins.ai-proxy.base")
 
 local require = require
 local pcall = pcall
 local ipairs = ipairs
-local unpack = unpack
 local type = type
 
 local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
@@ -190,11 +189,11 @@ local function get_load_balanced_provider(ctx, conf, ups_tab, request_table)
     return provider_name, provider_conf
 end
 
-ai_proxy.get_model_name = function (...)
+local function get_model_name(...)
 end
 
 
-ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
+local function proxy_request_to_llm(conf, request_table, ctx)
     local ups_tab = {}
     local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
     if algo == "chash" then
@@ -228,9 +227,7 @@ ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
 end
 
 
-function _M.access(conf, ctx)
-    local rets = {ai_proxy.access(conf, ctx)}
-    return unpack(rets)
-end
+_M.access = base.new(proxy_request_to_llm, get_model_name)
 
 
 return _M
101 changes: 5 additions & 96 deletions apisix/plugins/ai-proxy.lua
@@ -16,13 +16,10 @@
 --
 local core = require("apisix.core")
 local schema = require("apisix.plugins.ai-proxy.schema")
+local base = require("apisix.plugins.ai-proxy.base")
 
 local require = require
 local pcall = pcall
-local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
-local bad_request = ngx.HTTP_BAD_REQUEST
-local ngx_req = ngx.req
-local ngx_print = ngx.print
-local ngx_flush = ngx.flush
 
 local plugin_name = "ai-proxy"
 local _M = {
@@ -42,24 +39,12 @@ function _M.check_schema(conf)
 end
 
 
-local CONTENT_TYPE_JSON = "application/json"
-
-
-local function keepalive_or_close(conf, httpc)
-    if conf.set_keepalive then
-        httpc:set_keepalive(10000, 100)
-        return
-    end
-    httpc:close()
-end
-
-
-function _M.get_model_name(conf)
+local function get_model_name(conf)
     return conf.model.name
 end
 
 
-function _M.proxy_request_to_llm(conf, request_table, ctx)
+local function proxy_request_to_llm(conf, request_table, ctx)
     local ai_driver = require("apisix.plugins.ai-drivers." .. conf.model.provider)
     local extra_opts = {
         endpoint = core.table.try_read_attr(conf, "override", "endpoint"),
@@ -74,82 +59,6 @@ function _M.proxy_request_to_llm(conf, request_table, ctx)
     return res, nil, httpc
 end
 
-function _M.access(conf, ctx)
-    local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
-    if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
-        return bad_request, "unsupported content-type: " .. ct
-    end
-
-    local request_table, err = core.request.get_json_request_body_table()
-    if not request_table then
-        return bad_request, err
-    end
-
-    local ok, err = core.schema.check(schema.chat_request_schema, request_table)
-    if not ok then
-        return bad_request, "request format doesn't match schema: " .. err
-    end
-
-    request_table.model = _M.get_model_name(conf)
-
-    if core.table.try_read_attr(conf, "model", "options", "stream") then
-        request_table.stream = true
-    end
-
-    local res, err, httpc = _M.proxy_request_to_llm(conf, request_table, ctx)
-    if not res then
-        core.log.error("failed to send request to LLM service: ", err)
-        return internal_server_error
-    end
-
-    local body_reader = res.body_reader
-    if not body_reader then
-        core.log.error("LLM sent no response body")
-        return internal_server_error
-    end
-
-    if conf.passthrough then
-        ngx_req.init_body()
-        while true do
-            local chunk, err = body_reader() -- will read chunk by chunk
-            if err then
-                core.log.error("failed to read response chunk: ", err)
-                break
-            end
-            if not chunk then
-                break
-            end
-            ngx_req.append_body(chunk)
-        end
-        ngx_req.finish_body()
-        keepalive_or_close(conf, httpc)
-        return
-    end
-
-    if request_table.stream then
-        while true do
-            local chunk, err = body_reader() -- will read chunk by chunk
-            if err then
-                core.log.error("failed to read response chunk: ", err)
-                break
-            end
-            if not chunk then
-                break
-            end
-            ngx_print(chunk)
-            ngx_flush(true)
-        end
-        keepalive_or_close(conf, httpc)
-        return
-    else
-        local res_body, err = res:read_body()
-        if not res_body then
-            core.log.error("failed to read response body: ", err)
-            return internal_server_error
-        end
-        keepalive_or_close(conf, httpc)
-        return res.status, res_body
-    end
-end
+_M.access = base.new(proxy_request_to_llm, get_model_name)
 
 return _M
117 changes: 117 additions & 0 deletions apisix/plugins/ai-proxy/base.lua
@@ -0,0 +1,117 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements. See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License. You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+local CONTENT_TYPE_JSON = "application/json"
+local core = require("apisix.core")
+local bad_request = ngx.HTTP_BAD_REQUEST
+local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
+local schema = require("apisix.plugins.ai-proxy.schema")
+local ngx_req = ngx.req
+local ngx_print = ngx.print
+local ngx_flush = ngx.flush
+
+local function keepalive_or_close(conf, httpc)
+    if conf.set_keepalive then
+        httpc:set_keepalive(10000, 100)
+        return
+    end
+    httpc:close()
+end
+
+local _M = {}
+
+function _M.new(proxy_request_to_llm_func, get_model_name_func)
+    return function(conf, ctx)
+        local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
+        if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
+            return bad_request, "unsupported content-type: " .. ct
+        end
+
+        local request_table, err = core.request.get_json_request_body_table()
+        if not request_table then
+            return bad_request, err
+        end
+
+        local ok, err = core.schema.check(schema.chat_request_schema, request_table)
+        if not ok then
+            return bad_request, "request format doesn't match schema: " .. err
+        end
+
+        request_table.model = get_model_name_func(conf)
+
+        if core.table.try_read_attr(conf, "model", "options", "stream") then
+            request_table.stream = true
+        end
+
+        local res, err, httpc = proxy_request_to_llm_func(conf, request_table, ctx)
+        if not res then
+            core.log.error("failed to send request to LLM service: ", err)
+            return internal_server_error
+        end
+
+        local body_reader = res.body_reader
+        if not body_reader then
+            core.log.error("LLM sent no response body")
+            return internal_server_error
+        end
+
+        if conf.passthrough then
+            ngx_req.init_body()
+            while true do
+                local chunk, err = body_reader() -- will read chunk by chunk
+                if err then
+                    core.log.error("failed to read response chunk: ", err)
+                    break
+                end
+                if not chunk then
+                    break
+                end
+                ngx_req.append_body(chunk)
+            end
+            ngx_req.finish_body()
+            keepalive_or_close(conf, httpc)
+            return
+        end
+
+        if request_table.stream then
+            while true do
+                local chunk, err = body_reader() -- will read chunk by chunk
+                if err then
+                    core.log.error("failed to read response chunk: ", err)
+                    break
+                end
+                if not chunk then
+                    break
+                end
+                ngx_print(chunk)
+                ngx_flush(true)
+            end
+            keepalive_or_close(conf, httpc)
+            return
+        else
+            local res_body, err = res:read_body()
+            if not res_body then
+                core.log.error("failed to read response body: ", err)
+                return internal_server_error
+            end
+            keepalive_or_close(conf, httpc)
+            return res.status, res_body
+        end
+    end
+end
+
+return _M
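
What the new module buys: previously ai-proxy-multi reached into the ai-proxy module table and overwrote ai_proxy.get_model_name and ai_proxy.proxy_request_to_llm before delegating to ai_proxy.access; now base.new takes the two plugin-specific behaviors as arguments and returns a finished access handler, so both plugins reduce to one line of wiring. Below is a minimal sketch of how a further consumer could use the factory. The plugin shape and the driver call inside proxy_request_to_llm are assumptions for illustration; the diff only fixes the callback's return shape (res, err, httpc), not how it is produced.

-- Hypothetical consumer of apisix/plugins/ai-proxy/base.lua; only the
-- base.new contract (two callbacks in, one access handler out) comes from
-- this PR, the rest is illustrative.
local base = require("apisix.plugins.ai-proxy.base")

local _M = {}

-- Callback 1: the model name that base.new writes into request_table.model
-- before the request is forwarded.
local function get_model_name(conf)
    return conf.model.name
end

-- Callback 2: deliver the rewritten request to the upstream. base.new expects
-- res (exposing body_reader for streaming and read_body for buffered replies),
-- an error string, and the HTTP client handle used by keepalive_or_close.
local function proxy_request_to_llm(conf, request_table, ctx)
    local ai_driver = require("apisix.plugins.ai-drivers." .. conf.model.provider)
    -- the exact driver call is assumed here; the PR shows only the return shape
    local res, err, httpc = ai_driver.request(conf, request_table, {})
    if not res then
        return nil, err, nil
    end
    return res, nil, httpc
end

-- base.new supplies the shared pipeline: content-type check, JSON body
-- parsing, schema validation, passthrough/streaming handling, and keepalive.
_M.access = base.new(proxy_request_to_llm, get_model_name)

return _M

Passing plain local functions also keeps each plugin's internals private, whereas the old approach mutated shared module state at require time.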
2 changes: 1 addition & 1 deletion t/plugin/ai-proxy-multi2.t
@@ -289,7 +289,7 @@ passed


 === TEST 6: send request
---- custom_trusted_cert: /etc/ssl/cert.pem
+--- custom_trusted_cert: /etc/ssl/certs/ca-certificates.crt
 --- request
 POST /anything
 { "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }