From 309aa9d2d955e7ba69fa4abf1395411c3b2358f2 Mon Sep 17 00:00:00 2001 From: abc Date: Mon, 5 Feb 2024 16:56:48 +0800 Subject: [PATCH 1/7] feat: Create devcontainer.json --- .devcontainer/devcontainer.json | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 .devcontainer/devcontainer.json diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 000000000..55aea8bc1 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,29 @@ +{ + "name": "Rust Development", + "image": "mcr.microsoft.com/devcontainers/rust:dev-bullseye", + "mounts": [ + { + "source": "dind-var-lib-docker-${devcontainerId}", + "target": "/var/lib/docker", + "type": "volume" + } + ], + "customizations": { + "vscode": { + "extensions": [ + "rust-lang.rust-analyzer", + "chunsen.bracket-select", + "FittenTech.Fitten-Code", + "tamasfe.even-better-toml" + ] + } + }, + "forwardPorts": [], + "containerEnv": { + "TZ": "Asia/Shanghai" + }, + "postCreateCommand": "cargo build", + "features":{ + "ghcr.io/devcontainers/features/docker-in-docker:2": {} + } + } From 289f6fcaf999af154a26ec4d6acb2a6a7d11d75b Mon Sep 17 00:00:00 2001 From: abc Date: Mon, 19 Feb 2024 19:13:11 +0800 Subject: [PATCH 2/7] feat(wss): added wss connection proxy feat(wss): added topai wss --- Cargo.toml | 1 + crates/openai/Cargo.toml | 4 +- crates/openai/src/serve/proxy/toapi/mod.rs | 47 ++++++-- crates/openai/src/serve/proxy/toapi/model.rs | 9 ++ crates/openai/src/serve/proxy/toapi/stream.rs | 114 +++++++++++++++++- crates/openai/src/serve/router/chat/mod.rs | 85 ++++++++++++- 6 files changed, 246 insertions(+), 14 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 016544de1..e20630dca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ snmalloc-rs = { version = "0.3.4", optional = true } rpmalloc = { version = "0.2.2", optional = true } jemallocator = { package = "tikv-jemallocator", version = "0.5.4", optional = true } mimalloc = { version = "0.1.39", default-features = false, optional = true } +openssl = "0.10.63" [target.'cfg(target_family = "unix")'.dependencies] daemonize = "0.5.0" diff --git a/crates/openai/Cargo.toml b/crates/openai/Cargo.toml index 3fd388726..129fc60d5 100644 --- a/crates/openai/Cargo.toml +++ b/crates/openai/Cargo.toml @@ -45,6 +45,8 @@ pin-project-lite = { version = "0.2.13", optional = true } nom = { version = "7.1.3", optional = true } mime = { version = "0.3.17", optional = true } futures-timer = { version = "3.0.2", optional = true } +tokio-tungstenite = {version = "0.21.0" , features = ["rustls-tls-native-roots"] } +futures-util = "0.3" # mitm mitm = { path = "../mitm", optional = true } @@ -56,7 +58,7 @@ cbc = "0.1.2" rand_distr = "0.4.3" # axum -axum = { version = "0.6.20", features = ["http2", "multipart", "headers"], optional = true } +axum = { version = "0.6.20", features = ["http2", "multipart", "headers","ws"], optional = true } axum-extra ={ version = "0.8.0", features = ["cookie"], optional = true } axum-server = { version = "0.5.1", features = ["tls-rustls"], optional = true } tower-http = { version = "0.4.4", default-features = false, features = ["fs", "cors", "trace", "map-request-body", "util"], optional = true } diff --git a/crates/openai/src/serve/proxy/toapi/mod.rs b/crates/openai/src/serve/proxy/toapi/mod.rs index 2237be492..3c8016fb6 100644 --- a/crates/openai/src/serve/proxy/toapi/mod.rs +++ b/crates/openai/src/serve/proxy/toapi/mod.rs @@ -19,6 +19,7 @@ use crate::now_duration; use crate::serve::error::ProxyError; use crate::serve::ProxyResult; use crate::token; + use crate::{ arkose::ArkoseToken, chatgpt::model::req::{Content, ConversationMode, Messages, PostConvoRequest}, @@ -49,6 +50,7 @@ const SUGGESTIONS: [&'static str; 4] = [ /// Check if the request is supported pub(super) fn support(req: &RequestExt) -> bool { + print!( "support req: {} {}", req.uri.path() , req.method.as_str()); if req.uri.path().eq("/v1/chat/completions") && req.method.eq(&Method::POST) { if let Some(ref token) = req.bearer_auth() { return !token::check_sk_or_sess(token); @@ -152,6 +154,9 @@ pub(super) async fn send_request(req: RequestExt) -> Result Result().await?; + // Create a not stream response + let stream = stream::ws_stream_handler( + body.wss_url , + body.conversation_id , + config.model).await?; Ok(Sse::new(stream).into_response()) + } else { - // Create a not stream response - let no_stream = stream::not_stream_handler(event_source, config.model).await?; - Ok(no_stream.into_response()) + + // Get response body event source + let event_source = resp.bytes_stream().eventsource(); + if config.stream { + // Create a stream response + let stream = stream::stream_handler(event_source, config.model)?; + Ok(Sse::new(stream).into_response()) + } else { + // Create a not stream response + let no_stream = stream::not_stream_handler(event_source, config.model).await?; + Ok(no_stream.into_response()) + } } + } Err(err) => Ok(handle_error_response(err)?.into_response()), } diff --git a/crates/openai/src/serve/proxy/toapi/model.rs b/crates/openai/src/serve/proxy/toapi/model.rs index c2b700573..80b10167d 100644 --- a/crates/openai/src/serve/proxy/toapi/model.rs +++ b/crates/openai/src/serve/proxy/toapi/model.rs @@ -59,3 +59,12 @@ pub struct Delta<'a> { #[serde(skip_serializing_if = "Option::is_none")] pub content: Option<&'a str>, } + + + +#[derive(Deserialize, Default , Clone)] +pub struct WSStreamData { + pub body: String, + pub conversation_id: String, + pub more_body: bool, +} diff --git a/crates/openai/src/serve/proxy/toapi/stream.rs b/crates/openai/src/serve/proxy/toapi/stream.rs index 1eb2fb05d..bc8ad1c1f 100644 --- a/crates/openai/src/serve/proxy/toapi/stream.rs +++ b/crates/openai/src/serve/proxy/toapi/stream.rs @@ -1,16 +1,20 @@ use axum::response::sse::Event; use axum::Json; + + +use base64::{Engine as _, engine::general_purpose}; use eventsource_stream::EventStream; use futures_core::Stream; +use futures_util::StreamExt; use serde_json::Value; use std::convert::Infallible; -use tokio_stream::StreamExt; use crate::chatgpt::model::resp::{ConvoResponse, PostConvoResponse}; use crate::chatgpt::model::Role; use crate::serve::error::{ProxyError, ResponseError}; use crate::serve::ProxyResult; use crate::warn; +use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; use super::model; @@ -43,6 +47,9 @@ fn should_skip_conversion(convo: &ConvoResponse, pin_message_id: &str) -> bool { role_check || metadata_check } + + +// 这里通过websocket转发获取数据 pub(super) fn stream_handler( mut event_soure: EventStream< impl Stream> + std::marker::Unpin, @@ -103,6 +110,111 @@ pub(super) fn stream_handler( Ok(stream) } + +// convert tungstenite message to string +fn from_tungstenite(message: Message) -> String { + match message { + Message::Text(text) => { + let data = serde_json::from_str::(&text).unwrap(); + let body = data.body; + let decoded = general_purpose::STANDARD.decode(&body).unwrap(); + let result_data = String::from_utf8(decoded).unwrap() ; + if result_data.starts_with("data: ") { + let data_index = result_data.find("data: ").unwrap() + 6; + let data_end_index = result_data.find("\n\n").unwrap(); + let data_str = result_data[data_index..data_end_index].to_string(); + return data_str ; + } + return result_data ; + + }, + Message::Binary(_binary) => "".to_owned(), + Message::Ping(_ping) => "".to_owned(), + Message::Pong(_pong) => "".to_owned(), + Message::Close(Some(_close)) => "".to_owned(), + Message::Close(None) => "".to_owned(), + Message::Frame(_) => "".to_owned(), + } +} + + + + + + +// 这里通过websocket转发获取数据 +pub(super) async fn ws_stream_handler( + socket_url:String, + _conversation_id:String, + model: String, +) -> Result>, ResponseError> { + let id = super::generate_id(29); + let timestamp = super::current_timestamp()?; + let (ws_stream, _) = connect_async(socket_url.clone()).await.expect( format!("Failed to connect to {}", socket_url.clone()).as_str()); + + let (mut _write, mut read) = ws_stream.split(); + + + let stream = async_stream::stream! { + let mut previous_message = String::new(); + let mut pin_message_id = String::new(); + let mut set_role = true; + let mut stop: u8 = 0; + + while let Some(data) = read.next().await { + match data { + Ok(message) => { + let message_data = from_tungstenite(message); + if message_data.eq("[DONE]") { + yield Ok(Event::default().data(message_data)); + break; + } + // empty message means skip this message + if message_data.eq("") { + continue; + } + if let Ok(res) = serde_json::from_str::(&message_data) { + if let PostConvoResponse::Conversation(convo) = res { + + // Skip if role is not assistant + if should_skip_conversion(&convo, &pin_message_id) { + continue; + } + // Skip if conversation_id is not equal to conversation_id + if convo.conversation_id() != _conversation_id { + continue; + } + + let mut context = HandlerContext { + stop: &mut stop, + id: &id, + timestamp: ×tamp, + model: &model, + previous_message: &mut previous_message, + pin_message_id: &mut pin_message_id, + set_role: &mut set_role, + }; + + if let Ok(event) = event_convert_handler(&mut context, convo).await { + if stop == 0 || stop <= 1 { + yield Ok(event); + } + } + } + } + }, + Err(err) => { + warn!("event-source stream error: {}", err); + break; + } + } + + } + + }; + Ok(stream) +} + async fn event_convert_handler( context: &mut HandlerContext<'_>, convo: ConvoResponse, diff --git a/crates/openai/src/serve/router/chat/mod.rs b/crates/openai/src/serve/router/chat/mod.rs index cf09a9839..8c1bd5fbb 100644 --- a/crates/openai/src/serve/router/chat/mod.rs +++ b/crates/openai/src/serve/router/chat/mod.rs @@ -26,9 +26,16 @@ use axum_csrf::CsrfLayer; use axum_csrf::CsrfToken; use axum_csrf::Key; use axum_extra::extract::CookieJar; + +use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message , CloseFrame}; +use futures_util::{stream::StreamExt, sink::SinkExt}; +use std::net::SocketAddr; +use tokio_tungstenite::connect_async; + + use serde_json::{json, Value}; use std::collections::HashMap; -use std::net::SocketAddr; + use std::sync::OnceLock; use tower::ServiceBuilder; use tower_http::ServiceBuilderExt; @@ -55,7 +62,9 @@ use crate::{ token::model::Token, URL_CHATGPT_API, }; - +use tokio_tungstenite::tungstenite::{ + self as ts + }; use super::get_static_resource; use session::session::Session; use session::SessionExt; @@ -141,9 +150,81 @@ pub(super) fn config(router: Router, args: &Args) -> Router { .route("/fonts/*path", get(get_static_resource)) .route("/ulp/*path", get(get_static_resource)) .route("/sweetalert2/*path", get(get_static_resource)) + + // wss proxy + .route("/client/hubs/conversations",get(proxy_ws)) + + // 404 endpoint .fallback(error_404) } +// 定义一个用于接收查询参数的结构体 +#[derive(serde::Deserialize)] +struct WsQuery { + access_token: String, + host: String, +} + +async fn proxy_ws( + Query(query): Query, + ws: WebSocketUpgrade +) -> impl IntoResponse { + ws.protocols(["json.reliable.webpubsub.azure.v1"]).on_upgrade(move |socket| handle_socket(socket ,query.host, query.access_token)) +} +async fn handle_socket(socket: WebSocket , host:String, access_token: String) { + // 目标WebSocket服务器地址 + let base_url = format!("wss://{}/client/hubs/conversations?access_token={}" ,host, access_token) ; + let (target_ws, _) = connect_async(base_url.clone()).await.expect( format!("Failed to connect to {}", base_url.clone().as_str() ).as_str()); + let (mut client_sender, mut client_receiver) = socket.split(); + let (mut server_sender, mut server_receiver) = target_ws.split(); + // let connection_info = r#"{"type":"system","event":"connected","userId":"user-UJwFGo4Uv7qmR4zGwh5lFbHz","connectionId":"ZyritNFum8mCfAh3Otr9Ugi2kbOQD02","reconnectionToken":"eyJhbGciOiJIUzI1NiIsImtpZCI6IjI2ODE5MjEwIiwidHlwIjoiSldUIn0.eyJuYmYiOjE3MDgyNzk3NDUsImV4cCI6MTcwODg4NDU0NSwiaWF0IjoxNzA4Mjc5NzQ1LCJhdWQiOiJaeXJpdE5GdW04bUNmQWgzT3RyOVVnaTJrYk9RRDAyIn0.0G760UWJuZRJKEJ9YpsTsl4cA-MMNcwyyyTR27xQUro"}"#; + // 将服务器消息转发到客户端 + let server_to_client = async move { + while let Some(Ok(msg)) = server_receiver.next().await { + client_sender.send(from_tungstenite(msg).unwrap()).await.expect("Failed to send message to client"); + } + }; + // 将客户端消息转发到服务器 + let client_to_server = async move { + while let Some(Ok(msg)) = client_receiver.next().await { + server_sender.send(into_tungstenite(msg)).await.expect("Failed to send message to server"); + } + }; + tokio::join!(client_to_server, server_to_client); +} + + + +fn into_tungstenite(msg:Message) -> ts::Message { + match msg { + Message::Text(text) => ts::Message::Text(text), + Message::Binary(binary) => ts::Message::Binary(binary), + Message::Ping(ping) => ts::Message::Ping(ping), + Message::Pong(pong) => ts::Message::Pong(pong), + Message::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame { + code: ts::protocol::frame::coding::CloseCode::from(close.code), + reason: close.reason, + })), + Message::Close(None) => ts::Message::Close(None), + } +} + +fn from_tungstenite(message: ts::Message) -> Option { + match message { + ts::Message::Text(text) => Some(Message::Text( r#"{"type":"message","from":"server","dataType":"json","data":"#.to_string() + text.as_str() + "}")), + ts::Message::Binary(binary) => Some(Message::Binary(binary)), + ts::Message::Ping(ping) => Some(Message::Ping(ping)), + ts::Message::Pong(pong) => Some(Message::Pong(pong)), + ts::Message::Close(Some(close)) => Some(Message::Close(Some(CloseFrame { + code: close.code.into(), + reason: close.reason, + }))), + ts::Message::Close(None) => Some(Message::Close(None)), + // we can ignore `Frame` frames as recommended by the tungstenite maintainers + // https://github.com/snapview/tungstenite-rs/issues/268 + ts::Message::Frame(_) => None, + } +} /// Forwards the request to the auth provider async fn auth(token: CsrfToken) -> Result { From a139c83b2645f38335785932b9f5722bac12482a Mon Sep 17 00:00:00 2001 From: abc Date: Mon, 19 Feb 2024 20:16:31 +0800 Subject: [PATCH 3/7] feat(toapi): chat2api gizmo support --- crates/openai/src/arkose/mod.rs | 1 + crates/openai/src/chatgpt/model/req.rs | 1 + crates/openai/src/gpt_model.rs | 12 ++++++++++++ crates/openai/src/serve/proxy/toapi/mod.rs | 20 +++++++++++++++----- 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/crates/openai/src/arkose/mod.rs b/crates/openai/src/arkose/mod.rs index 78d6b0421..05ae9f382 100644 --- a/crates/openai/src/arkose/mod.rs +++ b/crates/openai/src/arkose/mod.rs @@ -121,6 +121,7 @@ impl From for Type { match value { GPTModel::Gpt35 => Type::GPT3, GPTModel::Gpt4 | GPTModel::Gpt4Mobile => Type::GPT4, + GPTModel::GptGizmo => Type::GPT4, } } } diff --git a/crates/openai/src/chatgpt/model/req.rs b/crates/openai/src/chatgpt/model/req.rs index 3384871c2..fc8aeaef7 100644 --- a/crates/openai/src/chatgpt/model/req.rs +++ b/crates/openai/src/chatgpt/model/req.rs @@ -38,6 +38,7 @@ pub enum Action { #[derive(Serialize, TypedBuilder)] pub struct ConversationMode<'a> { pub kind: &'a str, + pub gizmo_id:Option<&'a str>, } #[derive(Serialize, TypedBuilder)] diff --git a/crates/openai/src/gpt_model.rs b/crates/openai/src/gpt_model.rs index 1052f8cf9..37707f14f 100644 --- a/crates/openai/src/gpt_model.rs +++ b/crates/openai/src/gpt_model.rs @@ -7,6 +7,7 @@ pub enum GPTModel { Gpt35, Gpt4, Gpt4Mobile, + GptGizmo, } impl Serialize for GPTModel { @@ -15,6 +16,7 @@ impl Serialize for GPTModel { GPTModel::Gpt35 => "text-davinci-002-render-sha", GPTModel::Gpt4 => "gpt-4", GPTModel::Gpt4Mobile => "gpt-4-mobile", + GPTModel::GptGizmo => "gpt-4-gizmo", }; serializer.serialize_str(model) } @@ -34,6 +36,13 @@ impl GPTModel { _ => false, } } + + pub fn is_gizmo(&self) -> bool { + match self { + GPTModel::GptGizmo => true, + _ => false, + } + } } impl FromStr for GPTModel { @@ -48,6 +57,9 @@ impl FromStr for GPTModel { { Ok(GPTModel::Gpt35) } + s if s.starts_with("g-")=> { + Ok(GPTModel::GptGizmo) + } // If the model is gpt-4-mobile, we assume it's gpt-4-mobile "gpt-4-mobile" => Ok(GPTModel::Gpt4Mobile), // If the model starts with gpt-4, we assume it's gpt-4 diff --git a/crates/openai/src/serve/proxy/toapi/mod.rs b/crates/openai/src/serve/proxy/toapi/mod.rs index 3c8016fb6..6828e6ad1 100644 --- a/crates/openai/src/serve/proxy/toapi/mod.rs +++ b/crates/openai/src/serve/proxy/toapi/mod.rs @@ -50,7 +50,6 @@ const SUGGESTIONS: [&'static str; 4] = [ /// Check if the request is supported pub(super) fn support(req: &RequestExt) -> bool { - print!( "support req: {} {}", req.uri.path() , req.method.as_str()); if req.uri.path().eq("/v1/chat/completions") && req.method.eq(&Method::POST) { if let Some(ref token) = req.bearer_auth() { return !token::check_sk_or_sess(token); @@ -122,12 +121,23 @@ pub(super) async fn send_request(req: RequestExt) -> Result().await?; + println!("receive body: {:?}", body.wss_url) ; // Create a not stream response let stream = stream::ws_stream_handler( body.wss_url , From a4e5dd0fe3d40f95e439fa8d872c5e52cff6ad78 Mon Sep 17 00:00:00 2001 From: abc Date: Mon, 19 Feb 2024 20:39:46 +0800 Subject: [PATCH 4/7] fix: remove some comment --- crates/openai/src/serve/proxy/toapi/stream.rs | 3 +-- crates/openai/src/serve/router/chat/mod.rs | 4 ---- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/crates/openai/src/serve/proxy/toapi/stream.rs b/crates/openai/src/serve/proxy/toapi/stream.rs index bc8ad1c1f..09216df4d 100644 --- a/crates/openai/src/serve/proxy/toapi/stream.rs +++ b/crates/openai/src/serve/proxy/toapi/stream.rs @@ -49,7 +49,6 @@ fn should_skip_conversion(convo: &ConvoResponse, pin_message_id: &str) -> bool { -// 这里通过websocket转发获取数据 pub(super) fn stream_handler( mut event_soure: EventStream< impl Stream> + std::marker::Unpin, @@ -142,7 +141,7 @@ fn from_tungstenite(message: Message) -> String { -// 这里通过websocket转发获取数据 +// process webvscoket data and convert to event pub(super) async fn ws_stream_handler( socket_url:String, _conversation_id:String, diff --git a/crates/openai/src/serve/router/chat/mod.rs b/crates/openai/src/serve/router/chat/mod.rs index 8c1bd5fbb..827d17213 100644 --- a/crates/openai/src/serve/router/chat/mod.rs +++ b/crates/openai/src/serve/router/chat/mod.rs @@ -172,19 +172,15 @@ async fn proxy_ws( ws.protocols(["json.reliable.webpubsub.azure.v1"]).on_upgrade(move |socket| handle_socket(socket ,query.host, query.access_token)) } async fn handle_socket(socket: WebSocket , host:String, access_token: String) { - // 目标WebSocket服务器地址 let base_url = format!("wss://{}/client/hubs/conversations?access_token={}" ,host, access_token) ; let (target_ws, _) = connect_async(base_url.clone()).await.expect( format!("Failed to connect to {}", base_url.clone().as_str() ).as_str()); let (mut client_sender, mut client_receiver) = socket.split(); let (mut server_sender, mut server_receiver) = target_ws.split(); - // let connection_info = r#"{"type":"system","event":"connected","userId":"user-UJwFGo4Uv7qmR4zGwh5lFbHz","connectionId":"ZyritNFum8mCfAh3Otr9Ugi2kbOQD02","reconnectionToken":"eyJhbGciOiJIUzI1NiIsImtpZCI6IjI2ODE5MjEwIiwidHlwIjoiSldUIn0.eyJuYmYiOjE3MDgyNzk3NDUsImV4cCI6MTcwODg4NDU0NSwiaWF0IjoxNzA4Mjc5NzQ1LCJhdWQiOiJaeXJpdE5GdW04bUNmQWgzT3RyOVVnaTJrYk9RRDAyIn0.0G760UWJuZRJKEJ9YpsTsl4cA-MMNcwyyyTR27xQUro"}"#; - // 将服务器消息转发到客户端 let server_to_client = async move { while let Some(Ok(msg)) = server_receiver.next().await { client_sender.send(from_tungstenite(msg).unwrap()).await.expect("Failed to send message to client"); } }; - // 将客户端消息转发到服务器 let client_to_server = async move { while let Some(Ok(msg)) = client_receiver.next().await { server_sender.send(into_tungstenite(msg)).await.expect("Failed to send message to server"); From 23bd01a25bb488b5cc61e886cba07fc8251cdb0d Mon Sep 17 00:00:00 2001 From: abc Date: Mon, 19 Feb 2024 22:07:41 +0800 Subject: [PATCH 5/7] fix: added websocket_endpoint , auto box endpoint --- crates/openai/src/context/args.rs | 4 +++ crates/openai/src/context/init.rs | 1 + crates/openai/src/context/mod.rs | 7 +++++ crates/openai/src/serve/proxy/resp.rs | 31 +++++++++++++++++++++- crates/openai/src/serve/proxy/toapi/mod.rs | 4 +-- src/args.rs | 4 +++ src/daemon.rs | 2 ++ src/parse.rs | 10 +++++++ src/store/conf.rs | 2 ++ 9 files changed, 61 insertions(+), 4 deletions(-) diff --git a/crates/openai/src/context/args.rs b/crates/openai/src/context/args.rs index 36a8f5f54..f5cf4e3d1 100644 --- a/crates/openai/src/context/args.rs +++ b/crates/openai/src/context/args.rs @@ -84,6 +84,10 @@ pub struct Args { #[builder(default = false)] pub(crate) enable_arkose_proxy: bool, + /// websocket endpoint + #[builder(setter(into), default = Some("ws://127.0.0.1:7999".to_string()))] + pub(super) websocket_endpoint: Option, + /// Cloudflare captcha site key #[builder(setter(into), default)] pub(crate) cf_site_key: Option, diff --git a/crates/openai/src/context/init.rs b/crates/openai/src/context/init.rs index 3f1c49347..13cda3934 100644 --- a/crates/openai/src/context/init.rs +++ b/crates/openai/src/context/init.rs @@ -39,6 +39,7 @@ fn init_context(args: Args) -> Context { arkose_endpoint: args.arkose_endpoint, arkose_context: ArkoseVersionContext::new(), arkose_solver: args.arkose_solver, + websocket_endpoint: args.websocket_endpoint, arkose_gpt3_experiment: args.arkose_gpt3_experiment, arkose_gpt3_experiment_solver: args.arkose_gpt3_experiment_solver, arkose_solver_tguess_endpoint: args.arkose_solver_tguess_endpoint, diff --git a/crates/openai/src/context/mod.rs b/crates/openai/src/context/mod.rs index b8dc6546e..4c03af2ef 100644 --- a/crates/openai/src/context/mod.rs +++ b/crates/openai/src/context/mod.rs @@ -60,6 +60,8 @@ pub struct Context { cf_turnstile: Option, /// Arkose endpoint arkose_endpoint: Option, + /// Websocket endpoint + websocket_endpoint: Option, /// Enable Arkose GPT-3.5 experiment arkose_gpt3_experiment: bool, /// Enable Arkose GPT-3.5 experiment solver @@ -142,6 +144,11 @@ impl Context { self.arkose_gpt3_experiment_solver } + /// Get websocket endpoint + pub fn websocket_endpoint(&self) -> Option<&str> { + self.websocket_endpoint.as_deref() + } + /// Get the arkose context pub fn arkose_context(&self) -> &arkose::ArkoseVersionContext<'static> { &self.arkose_context diff --git a/crates/openai/src/serve/proxy/resp.rs b/crates/openai/src/serve/proxy/resp.rs index 3a1799fa3..0c2ea3844 100644 --- a/crates/openai/src/serve/proxy/resp.rs +++ b/crates/openai/src/serve/proxy/resp.rs @@ -9,6 +9,7 @@ use axum::http::header; use axum::response::{IntoResponse, Response}; use axum_extra::extract::cookie; use axum_extra::extract::cookie::Cookie; +use regex::Regex; use serde_json::Value; use crate::serve::error::ResponseError; @@ -61,8 +62,36 @@ pub(crate) async fn response_convert( } } + // Modify ws endpoint response + if let Some(ws_up_stream) = with_context!(websocket_endpoint) { + if !ws_up_stream.is_empty() + && resp.inner.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap().eq("application/json") + && ( resp.inner.url().path().contains("register-websocket") || resp.inner.url().path().ends_with("/backend-api/conversation")) { + let mut json = resp + .inner + .text() + .await + .map_err(ResponseError::InternalServerError)?; + + let re = Regex::new(r"wss://([^/]+)/client/hubs/conversations\?").unwrap(); + + if let Some(caps) = re.captures(&json) { + if let Some(matched) = caps.get(1) { + let matched_str = matched.as_str(); + let replacement = format!("{}/client/hubs/conversations?host={}&", ws_up_stream, matched_str); + json = re.replace(&json, replacement.as_str()).to_string(); + } + } + return Ok(builder + .body(StreamBody::new(Body::from(json))) + .map_err(ResponseError::InternalServerError)? + .into_response()); + + } + + } // Modify files endpoint response - if with_context!(enable_file_proxy) && resp.inner.url().path().contains("/backend-api/files") { + if with_context!(enable_file_proxy) && resp.inner.url().path().contains("/backend-api/files") { let url = resp.inner.url().clone(); // Files endpoint handling let mut json = resp diff --git a/crates/openai/src/serve/proxy/toapi/mod.rs b/crates/openai/src/serve/proxy/toapi/mod.rs index 6828e6ad1..e00d80a4f 100644 --- a/crates/openai/src/serve/proxy/toapi/mod.rs +++ b/crates/openai/src/serve/proxy/toapi/mod.rs @@ -196,12 +196,10 @@ pub(super) async fn response_convert( let config = resp_ext.context.ok_or(ResponseError::InternalServerError( ProxyError::RequestContentIsEmpty, ))?; - // print resp.headers().get(header::CONTENT_TYPE) - // 判断cotnent-type是否是application/json + // process applicaiton for wss if resp.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap().eq("application/json") { // Get response body let body = resp.json::().await?; - println!("receive body: {:?}", body.wss_url) ; // Create a not stream response let stream = stream::ws_stream_handler( body.wss_url , diff --git a/src/args.rs b/src/args.rs index e3acb771c..a8be286cd 100644 --- a/src/args.rs +++ b/src/args.rs @@ -172,6 +172,10 @@ pub struct ServeArgs { /// Arkose endpoint, e.g. https://client-api.arkoselabs.com #[clap(long, value_parser = parse::parse_url)] pub(super) arkose_endpoint: Option, + + /// Websocket endpoint, e.g. wss://server.com/ws + #[clap(long, value_parser = parse::parse_websocket_url , env = "WEBSOCKET_ENDPOINT")] + pub(super) websocket_endpoint: Option, /// Enable Arkose GPT-3.5 experiment #[clap(short = 'E', long, default_value = "false")] diff --git a/src/daemon.rs b/src/daemon.rs index 557b57b0f..dd557a3dd 100644 --- a/src/daemon.rs +++ b/src/daemon.rs @@ -71,6 +71,7 @@ pub(super) fn serve(mut args: ServeArgs, relative_path: bool) -> anyhow::Result< .cf_secret_key(args.cf_secret_key) .enable_webui(args.enable_webui) .arkose_endpoint(args.arkose_endpoint) + .websocket_endpoint(args.websocket_endpoint) .arkose_gpt3_experiment(args.arkose_gpt3_experiment) .arkose_gpt3_experiment_solver(args.arkose_gpt3_experiment_solver) .arkose_solver(arkose_solver) @@ -291,6 +292,7 @@ pub(super) fn generate_template(out: Option) -> anyhow::Result<()> { pkey: PathBuf::from("ca/key.pem"), arkose_gpt3_experiment: false, enable_file_proxy: false, + websocket_endpoint: Some("ws://127.0.0.1:7999".to_string()), proxies: Some(vec![ proxy::Proxy::try_from(("all", "socks5://127.0.0.1:8888".parse::()?))?, proxy::Proxy::try_from(("all", "http://127.0.0.1:8889".parse::()?))?, diff --git a/src/parse.rs b/src/parse.rs index bbaeedbb9..286c502a3 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -22,6 +22,16 @@ pub fn parse_url(s: &str) -> anyhow::Result { _ => anyhow::bail!("Unsupported protocol: {}", protocol), } } +// websocket url parse +pub fn parse_websocket_url(s: &str) -> anyhow::Result { + let url = url::Url::parse(s) + .context("The WebSocket Proxy Url format must be `ws[s]://user:pass@ip:port`")?; + let protocol = url.scheme().to_string(); + match protocol.as_str() { + "wss" | "ws" => Ok(s.to_string()), + _ => anyhow::bail!("Unsupported protocol: {}", protocol), + } +} // proxy proto, format: proto|type, support proto: all/api/auth/arkose, support type: ip/url/cidr pub fn parse_proxies_url(s: &str) -> anyhow::Result> { diff --git a/src/store/conf.rs b/src/store/conf.rs index c191d44e6..7e713b6f3 100644 --- a/src/store/conf.rs +++ b/src/store/conf.rs @@ -15,6 +15,8 @@ pub struct Conf { pub unofficial_api: Option, /// Client proxy. Format: protocol://user:pass@ip:port pub proxy: Option, + /// Config wss endpoint. Format: wss://example.com + pub websocket_endpoint: Option, /// About the solver client by ArkoseLabs pub arkose_solver: Solver, /// About the solver client key by ArkoseLabs From eb7547300c1827debcfa6b0e5fb8c9e7137a789a Mon Sep 17 00:00:00 2001 From: abc Date: Tue, 20 Feb 2024 00:34:20 +0800 Subject: [PATCH 6/7] =?UTF-8?q?fix(ws):=20=E4=BF=AE=E5=A4=8Dwebsocket?= =?UTF-8?q?=E7=9A=84=E6=95=B0=E6=8D=AE=E5=8D=8F=E8=AE=AE=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crates/openai/src/serve/proxy/toapi/model.rs | 7 +++++ crates/openai/src/serve/proxy/toapi/stream.rs | 31 +++++++++++++------ crates/openai/src/serve/router/chat/mod.rs | 15 ++++++--- 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/crates/openai/src/serve/proxy/toapi/model.rs b/crates/openai/src/serve/proxy/toapi/model.rs index 80b10167d..87a57ffb9 100644 --- a/crates/openai/src/serve/proxy/toapi/model.rs +++ b/crates/openai/src/serve/proxy/toapi/model.rs @@ -64,6 +64,13 @@ pub struct Delta<'a> { #[derive(Deserialize, Default , Clone)] pub struct WSStreamData { + pub data: Option, + #[serde(rename = "type")] + pub msg_type: String, +} + +#[derive(Deserialize, Default , Clone)] +pub struct WSStreamDataBody { pub body: String, pub conversation_id: String, pub more_body: bool, diff --git a/crates/openai/src/serve/proxy/toapi/stream.rs b/crates/openai/src/serve/proxy/toapi/stream.rs index 09216df4d..814e508d5 100644 --- a/crates/openai/src/serve/proxy/toapi/stream.rs +++ b/crates/openai/src/serve/proxy/toapi/stream.rs @@ -15,6 +15,8 @@ use crate::serve::error::{ProxyError, ResponseError}; use crate::serve::ProxyResult; use crate::warn; use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}; use super::model; @@ -115,16 +117,20 @@ fn from_tungstenite(message: Message) -> String { match message { Message::Text(text) => { let data = serde_json::from_str::(&text).unwrap(); - let body = data.body; - let decoded = general_purpose::STANDARD.decode(&body).unwrap(); - let result_data = String::from_utf8(decoded).unwrap() ; - if result_data.starts_with("data: ") { - let data_index = result_data.find("data: ").unwrap() + 6; - let data_end_index = result_data.find("\n\n").unwrap(); - let data_str = result_data[data_index..data_end_index].to_string(); - return data_str ; + if data.msg_type.eq("message"){ + let body = data.data.unwrap().body; + let decoded = general_purpose::STANDARD.decode(&body).unwrap(); + let result_data = String::from_utf8(decoded).unwrap() ; + if result_data.starts_with("data: ") { + let data_index = result_data.find("data: ").unwrap() + 6; + let data_end_index = result_data.find("\n\n").unwrap(); + let data_str = result_data[data_index..data_end_index].to_string(); + return data_str ; + } + return result_data ; + } - return result_data ; + return "".to_owned() }, Message::Binary(_binary) => "".to_owned(), @@ -149,7 +155,12 @@ pub(super) async fn ws_stream_handler( ) -> Result>, ResponseError> { let id = super::generate_id(29); let timestamp = super::current_timestamp()?; - let (ws_stream, _) = connect_async(socket_url.clone()).await.expect( format!("Failed to connect to {}", socket_url.clone()).as_str()); + + let mut request = socket_url.into_client_request().unwrap(); + request.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("json.reliable.webpubsub.azure.v1")); // Or other modifications + let (ws_stream, _) = connect_async(request) + .await + .expect("Failed to connect"); let (mut _write, mut read) = ws_stream.split(); diff --git a/crates/openai/src/serve/router/chat/mod.rs b/crates/openai/src/serve/router/chat/mod.rs index 827d17213..5ab98f2f3 100644 --- a/crates/openai/src/serve/router/chat/mod.rs +++ b/crates/openai/src/serve/router/chat/mod.rs @@ -29,9 +29,10 @@ use axum_extra::extract::CookieJar; use axum::extract::ws::{WebSocket, WebSocketUpgrade, Message , CloseFrame}; use futures_util::{stream::StreamExt, sink::SinkExt}; -use std::net::SocketAddr; use tokio_tungstenite::connect_async; - +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use std::net::SocketAddr; +use tokio_tungstenite::tungstenite::http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}; use serde_json::{json, Value}; use std::collections::HashMap; @@ -173,7 +174,12 @@ async fn proxy_ws( } async fn handle_socket(socket: WebSocket , host:String, access_token: String) { let base_url = format!("wss://{}/client/hubs/conversations?access_token={}" ,host, access_token) ; - let (target_ws, _) = connect_async(base_url.clone()).await.expect( format!("Failed to connect to {}", base_url.clone().as_str() ).as_str()); + + let mut request = base_url.into_client_request().unwrap(); + request.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("json.reliable.webpubsub.azure.v1")); // Or other modifications + let (target_ws, _) = connect_async(request) + .await + .expect("Failed to connect"); let (mut client_sender, mut client_receiver) = socket.split(); let (mut server_sender, mut server_receiver) = target_ws.split(); let server_to_client = async move { @@ -207,7 +213,8 @@ fn into_tungstenite(msg:Message) -> ts::Message { fn from_tungstenite(message: ts::Message) -> Option { match message { - ts::Message::Text(text) => Some(Message::Text( r#"{"type":"message","from":"server","dataType":"json","data":"#.to_string() + text.as_str() + "}")), + //ts::Message::Text(text) => Some(Message::Text( r#"{"type":"message","from":"server","dataType":"json","data":"#.to_string() + text.as_str() + "}")), + ts::Message::Text(text) => Some(Message::Text(text)), ts::Message::Binary(binary) => Some(Message::Binary(binary)), ts::Message::Ping(ping) => Some(Message::Ping(ping)), ts::Message::Pong(pong) => Some(Message::Pong(pong)), From 8a4d5356be8f8fd0821e60dd88f3bec522af55d5 Mon Sep 17 00:00:00 2001 From: abc Date: Tue, 20 Feb 2024 02:52:46 +0800 Subject: [PATCH 7/7] fix(ws): remove exprect --- crates/openai/src/serve/router/chat/mod.rs | 39 +++++++++++++--------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/crates/openai/src/serve/router/chat/mod.rs b/crates/openai/src/serve/router/chat/mod.rs index 5ab98f2f3..d1e4e4de7 100644 --- a/crates/openai/src/serve/router/chat/mod.rs +++ b/crates/openai/src/serve/router/chat/mod.rs @@ -177,22 +177,31 @@ async fn handle_socket(socket: WebSocket , host:String, access_token: String) { let mut request = base_url.into_client_request().unwrap(); request.headers_mut().insert(SEC_WEBSOCKET_PROTOCOL, HeaderValue::from_static("json.reliable.webpubsub.azure.v1")); // Or other modifications - let (target_ws, _) = connect_async(request) - .await - .expect("Failed to connect"); - let (mut client_sender, mut client_receiver) = socket.split(); - let (mut server_sender, mut server_receiver) = target_ws.split(); - let server_to_client = async move { - while let Some(Ok(msg)) = server_receiver.next().await { - client_sender.send(from_tungstenite(msg).unwrap()).await.expect("Failed to send message to client"); - } - }; - let client_to_server = async move { - while let Some(Ok(msg)) = client_receiver.next().await { - server_sender.send(into_tungstenite(msg)).await.expect("Failed to send message to server"); + //let Ok((target_ws, _)) = + + match connect_async(request) .await { + Ok(target_ws) => { + + let (mut client_sender, mut client_receiver) = socket.split(); + let (mut server_sender, mut server_receiver) = target_ws.0.split(); + let server_to_client = async move { + while let Some(Ok(msg)) = server_receiver.next().await { + let _ = client_sender.send(from_tungstenite(msg).unwrap()).await; + } + }; + let client_to_server = async move { + while let Some(Ok(msg)) = client_receiver.next().await { + let _ = server_sender.send(into_tungstenite(msg)).await; + } + }; + tokio::join!(client_to_server, server_to_client); + }, + Err(err) => { + eprintln!("Error connecting to target: {}", err); + return; } - }; - tokio::join!(client_to_server, server_to_client); + } + }