diff --git a/apisix/plugins/ai-proxy-multi.lua b/apisix/plugins/ai-proxy-multi.lua
index 48f0dea944aa..2ffaa1595b79 100644
--- a/apisix/plugins/ai-proxy-multi.lua
+++ b/apisix/plugins/ai-proxy-multi.lua
@@ -17,13 +17,12 @@
 
 local core = require("apisix.core")
 local schema = require("apisix.plugins.ai-proxy.schema")
-local ai_proxy = require("apisix.plugins.ai-proxy")
 local plugin = require("apisix.plugin")
+local base = require("apisix.plugins.ai-proxy.base")
 
 local require = require
 local pcall = pcall
 local ipairs = ipairs
-local unpack = unpack
 local type = type
 
 local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
@@ -190,11 +189,11 @@ local function get_load_balanced_provider(ctx, conf, ups_tab, request_table)
     return provider_name, provider_conf
 end
 
-ai_proxy.get_model_name = function (...)
+local function get_model_name(...)
 end
 
 
-ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
+local function proxy_request_to_llm(conf, request_table, ctx)
     local ups_tab = {}
     local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
     if algo == "chash" then
@@ -228,9 +227,7 @@ ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
 end
 
 
-function _M.access(conf, ctx)
-    local rets = {ai_proxy.access(conf, ctx)}
-    return unpack(rets)
-end
+_M.access = base.new(proxy_request_to_llm, get_model_name)
+
 
 return _M
diff --git a/apisix/plugins/ai-proxy.lua b/apisix/plugins/ai-proxy.lua
index c27ca9a3b995..ffc82f85672f 100644
--- a/apisix/plugins/ai-proxy.lua
+++ b/apisix/plugins/ai-proxy.lua
@@ -16,13 +16,10 @@
 --
 local core = require("apisix.core")
 local schema = require("apisix.plugins.ai-proxy.schema")
+local base = require("apisix.plugins.ai-proxy.base")
+
 local require = require
 local pcall = pcall
-local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
-local bad_request = ngx.HTTP_BAD_REQUEST
-local ngx_req = ngx.req
-local ngx_print = ngx.print
-local ngx_flush = ngx.flush
 
 local plugin_name = "ai-proxy"
 local _M = {
@@ -42,24 +39,12 @@ function _M.check_schema(conf)
 end
 
 
-local CONTENT_TYPE_JSON = "application/json"
-
-
-local function keepalive_or_close(conf, httpc)
-    if conf.set_keepalive then
-        httpc:set_keepalive(10000, 100)
-        return
-    end
-    httpc:close()
-end
-
-
-function _M.get_model_name(conf)
+local function get_model_name(conf)
     return conf.model.name
 end
 
 
-function _M.proxy_request_to_llm(conf, request_table, ctx)
+local function proxy_request_to_llm(conf, request_table, ctx)
     local ai_driver = require("apisix.plugins.ai-drivers." .. conf.model.provider)
     local extra_opts = {
         endpoint = core.table.try_read_attr(conf, "override", "endpoint"),
@@ -74,82 +59,6 @@ function _M.proxy_request_to_llm(conf, request_table, ctx)
     return res, nil, httpc
 end
 
-function _M.access(conf, ctx)
-    local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
-    if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
-        return bad_request, "unsupported content-type: " .. ct
-    end
-
-    local request_table, err = core.request.get_json_request_body_table()
-    if not request_table then
-        return bad_request, err
-    end
-
-    local ok, err = core.schema.check(schema.chat_request_schema, request_table)
-    if not ok then
-        return bad_request, "request format doesn't match schema: " .. err
-    end
-
-    request_table.model = _M.get_model_name(conf)
-
-    if core.table.try_read_attr(conf, "model", "options", "stream") then
-        request_table.stream = true
-    end
-
-    local res, err, httpc = _M.proxy_request_to_llm(conf, request_table, ctx)
-    if not res then
-        core.log.error("failed to send request to LLM service: ", err)
-        return internal_server_error
-    end
-
-    local body_reader = res.body_reader
-    if not body_reader then
-        core.log.error("LLM sent no response body")
-        return internal_server_error
-    end
-
-    if conf.passthrough then
-        ngx_req.init_body()
-        while true do
-            local chunk, err = body_reader() -- will read chunk by chunk
-            if err then
-                core.log.error("failed to read response chunk: ", err)
-                break
-            end
-            if not chunk then
-                break
-            end
-            ngx_req.append_body(chunk)
-        end
-        ngx_req.finish_body()
-        keepalive_or_close(conf, httpc)
-        return
-    end
-
-    if request_table.stream then
-        while true do
-            local chunk, err = body_reader() -- will read chunk by chunk
-            if err then
-                core.log.error("failed to read response chunk: ", err)
-                break
-            end
-            if not chunk then
-                break
-            end
-            ngx_print(chunk)
-            ngx_flush(true)
-        end
-        keepalive_or_close(conf, httpc)
-        return
-    else
-        local res_body, err = res:read_body()
-        if not res_body then
-            core.log.error("failed to read response body: ", err)
-            return internal_server_error
-        end
-        keepalive_or_close(conf, httpc)
-        return res.status, res_body
-    end
-end
+_M.access = base.new(proxy_request_to_llm, get_model_name)
 
 return _M
diff --git a/apisix/plugins/ai-proxy/base.lua b/apisix/plugins/ai-proxy/base.lua
new file mode 100644
index 000000000000..bd6e945ba788
--- /dev/null
+++ b/apisix/plugins/ai-proxy/base.lua
@@ -0,0 +1,117 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one or more
+-- contributor license agreements.  See the NOTICE file distributed with
+-- this work for additional information regarding copyright ownership.
+-- The ASF licenses this file to You under the Apache License, Version 2.0
+-- (the "License"); you may not use this file except in compliance with
+-- the License.  You may obtain a copy of the License at
+--
+--     http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+--
+
+local CONTENT_TYPE_JSON = "application/json"
+local core = require("apisix.core")
+local bad_request = ngx.HTTP_BAD_REQUEST
+local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
+local schema = require("apisix.plugins.ai-proxy.schema")
+local ngx_req = ngx.req
+local ngx_print = ngx.print
+local ngx_flush = ngx.flush
+
+local function keepalive_or_close(conf, httpc)
+    if conf.set_keepalive then
+        httpc:set_keepalive(10000, 100)
+        return
+    end
+    httpc:close()
+end
+
+local _M = {}
+
+function _M.new(proxy_request_to_llm_func, get_model_name_func)
+    return function(conf, ctx)
+        local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
+        if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
+            return bad_request, "unsupported content-type: " .. ct
+        end
+
+        local request_table, err = core.request.get_json_request_body_table()
+        if not request_table then
+            return bad_request, err
+        end
+
+        local ok, err = core.schema.check(schema.chat_request_schema, request_table)
+        if not ok then
+            return bad_request, "request format doesn't match schema: " .. err
+        end
+
+        request_table.model = get_model_name_func(conf)
+
+        if core.table.try_read_attr(conf, "model", "options", "stream") then
+            request_table.stream = true
+        end
+
+        local res, err, httpc = proxy_request_to_llm_func(conf, request_table, ctx)
+        if not res then
+            core.log.error("failed to send request to LLM service: ", err)
+            return internal_server_error
+        end
+
+        local body_reader = res.body_reader
+        if not body_reader then
+            core.log.error("LLM sent no response body")
+            return internal_server_error
+        end
+
+        if conf.passthrough then
+            ngx_req.init_body()
+            while true do
+                local chunk, err = body_reader() -- will read chunk by chunk
+                if err then
+                    core.log.error("failed to read response chunk: ", err)
+                    break
+                end
+                if not chunk then
+                    break
+                end
+                ngx_req.append_body(chunk)
+            end
+            ngx_req.finish_body()
+            keepalive_or_close(conf, httpc)
+            return
+        end
+
+        if request_table.stream then
+            while true do
+                local chunk, err = body_reader() -- will read chunk by chunk
+                if err then
+                    core.log.error("failed to read response chunk: ", err)
+                    break
+                end
+                if not chunk then
+                    break
+                end
+                ngx_print(chunk)
+                ngx_flush(true)
+            end
+            keepalive_or_close(conf, httpc)
+            return
+        else
+            local res_body, err = res:read_body()
+            if not res_body then
+                core.log.error("failed to read response body: ", err)
+                return internal_server_error
+            end
+            keepalive_or_close(conf, httpc)
+            return res.status, res_body
+        end
+    end
+end
+
+return _M
diff --git a/t/plugin/ai-proxy-multi2.t b/t/plugin/ai-proxy-multi2.t
index af5c4e880cb8..9a77dc5f7b67 100644
--- a/t/plugin/ai-proxy-multi2.t
+++ b/t/plugin/ai-proxy-multi2.t
@@ -289,7 +289,7 @@ passed
 
 
 === TEST 6: send request
---- custom_trusted_cert: /etc/ssl/cert.pem
+--- custom_trusted_cert: /etc/ssl/certs/ca-certificates.crt
 --- request
 POST /anything
 { "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }