Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/websocket #544

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"name": "Rust Development",
"image": "mcr.microsoft.com/devcontainers/rust:dev-bullseye",
"mounts": [
{
"source": "dind-var-lib-docker-${devcontainerId}",
"target": "/var/lib/docker",
"type": "volume"
}
],
"customizations": {
"vscode": {
"extensions": [
"rust-lang.rust-analyzer",
"chunsen.bracket-select",
"FittenTech.Fitten-Code",
"tamasfe.even-better-toml"
]
}
},
"forwardPorts": [],
"containerEnv": {
"TZ": "Asia/Shanghai"
},
"postCreateCommand": "cargo build",
"features": {
"ghcr.io/devcontainers/features/docker-in-docker:2": {}
}
}
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ snmalloc-rs = { version = "0.3.4", optional = true }
rpmalloc = { version = "0.2.2", optional = true }
jemallocator = { package = "tikv-jemallocator", version = "0.5.4", optional = true }
mimalloc = { version = "0.1.39", default-features = false, optional = true }
openssl = "0.10.63"

[target.'cfg(target_family = "unix")'.dependencies]
daemonize = "0.5.0"
Expand Down
4 changes: 3 additions & 1 deletion crates/openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ pin-project-lite = { version = "0.2.13", optional = true }
nom = { version = "7.1.3", optional = true }
mime = { version = "0.3.17", optional = true }
futures-timer = { version = "3.0.2", optional = true }
tokio-tungstenite = { version = "0.21.0", features = ["rustls-tls-native-roots"] }
futures-util = "0.3"

# mitm
mitm = { path = "../mitm", optional = true }
Expand All @@ -56,7 +58,7 @@ cbc = "0.1.2"
rand_distr = "0.4.3"

# axum
axum = { version = "0.6.20", features = ["http2", "multipart", "headers"], optional = true }
axum = { version = "0.6.20", features = ["http2", "multipart", "headers", "ws"], optional = true }
axum-extra ={ version = "0.8.0", features = ["cookie"], optional = true }
axum-server = { version = "0.5.1", features = ["tls-rustls"], optional = true }
tower-http = { version = "0.4.4", default-features = false, features = ["fs", "cors", "trace", "map-request-body", "util"], optional = true }
Expand Down
1 change: 1 addition & 0 deletions crates/openai/src/arkose/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ impl From<GPTModel> for Type {
match value {
GPTModel::Gpt35 => Type::GPT3,
GPTModel::Gpt4 | GPTModel::Gpt4Mobile => Type::GPT4,
GPTModel::GptGizmo => Type::GPT4,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/openai/src/chatgpt/model/req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub enum Action {
/// Conversation mode sent in the conversation request payload.
#[derive(Serialize, TypedBuilder)]
pub struct ConversationMode<'a> {
    /// Mode kind, e.g. "primary_assistant" or "gizmo_interaction".
    pub kind: &'a str,
    /// Gizmo (custom GPT) id; presumably only set when `kind` is
    /// "gizmo_interaction" — confirm against callers.
    pub gizmo_id: Option<&'a str>,
}

#[derive(Serialize, TypedBuilder)]
Expand Down
4 changes: 4 additions & 0 deletions crates/openai/src/context/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ pub struct Args {
#[builder(default = false)]
pub(crate) enable_arkose_proxy: bool,

/// websocket endpoint
#[builder(setter(into), default = Some("ws://127.0.0.1:7999".to_string()))]
pub(super) websocket_endpoint: Option<String>,

/// Cloudflare captcha site key
#[builder(setter(into), default)]
pub(crate) cf_site_key: Option<String>,
Expand Down
1 change: 1 addition & 0 deletions crates/openai/src/context/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ fn init_context(args: Args) -> Context {
arkose_endpoint: args.arkose_endpoint,
arkose_context: ArkoseVersionContext::new(),
arkose_solver: args.arkose_solver,
websocket_endpoint: args.websocket_endpoint,
arkose_gpt3_experiment: args.arkose_gpt3_experiment,
arkose_gpt3_experiment_solver: args.arkose_gpt3_experiment_solver,
arkose_solver_tguess_endpoint: args.arkose_solver_tguess_endpoint,
Expand Down
7 changes: 7 additions & 0 deletions crates/openai/src/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ pub struct Context {
cf_turnstile: Option<CfTurnstile>,
/// Arkose endpoint
arkose_endpoint: Option<String>,
/// Websocket endpoint
websocket_endpoint: Option<String>,
/// Enable Arkose GPT-3.5 experiment
arkose_gpt3_experiment: bool,
/// Enable Arkose GPT-3.5 experiment solver
Expand Down Expand Up @@ -142,6 +144,11 @@ impl Context {
self.arkose_gpt3_experiment_solver
}

/// Returns the configured websocket endpoint, if one was set.
pub fn websocket_endpoint(&self) -> Option<&str> {
    self.websocket_endpoint.as_ref().map(String::as_str)
}

/// Get the arkose context
pub fn arkose_context(&self) -> &arkose::ArkoseVersionContext<'static> {
&self.arkose_context
Expand Down
12 changes: 12 additions & 0 deletions crates/openai/src/gpt_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub enum GPTModel {
Gpt35,
Gpt4,
Gpt4Mobile,
GptGizmo,
}

impl Serialize for GPTModel {
Expand All @@ -15,6 +16,7 @@ impl Serialize for GPTModel {
GPTModel::Gpt35 => "text-davinci-002-render-sha",
GPTModel::Gpt4 => "gpt-4",
GPTModel::Gpt4Mobile => "gpt-4-mobile",
GPTModel::GptGizmo => "gpt-4-gizmo",
};
serializer.serialize_str(model)
}
Expand All @@ -34,6 +36,13 @@ impl GPTModel {
_ => false,
}
}

/// Returns `true` when the model is the Gizmo (custom GPT, "g-"-prefixed)
/// variant. Uses `matches!` instead of a manual bool-returning `match`
/// (clippy: match_like_matches_macro).
pub fn is_gizmo(&self) -> bool {
    matches!(self, GPTModel::GptGizmo)
}
}

impl FromStr for GPTModel {
Expand All @@ -48,6 +57,9 @@ impl FromStr for GPTModel {
{
Ok(GPTModel::Gpt35)
}
s if s.starts_with("g-")=> {
Ok(GPTModel::GptGizmo)
}
// If the model is gpt-4-mobile, we assume it's gpt-4-mobile
"gpt-4-mobile" => Ok(GPTModel::Gpt4Mobile),
// If the model starts with gpt-4, we assume it's gpt-4
Expand Down
31 changes: 30 additions & 1 deletion crates/openai/src/serve/proxy/resp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use axum::http::header;
use axum::response::{IntoResponse, Response};
use axum_extra::extract::cookie;
use axum_extra::extract::cookie::Cookie;
use regex::Regex;
use serde_json::Value;

use crate::serve::error::ResponseError;
Expand Down Expand Up @@ -61,8 +62,36 @@ pub(crate) async fn response_convert(
}
}

// Modify ws endpoint response
if let Some(ws_up_stream) = with_context!(websocket_endpoint) {
if !ws_up_stream.is_empty()
&& resp.inner.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap().eq("application/json")
&& ( resp.inner.url().path().contains("register-websocket") || resp.inner.url().path().ends_with("/backend-api/conversation")) {
let mut json = resp
.inner
.text()
.await
.map_err(ResponseError::InternalServerError)?;

let re = Regex::new(r"wss://([^/]+)/client/hubs/conversations\?").unwrap();

if let Some(caps) = re.captures(&json) {
if let Some(matched) = caps.get(1) {
let matched_str = matched.as_str();
let replacement = format!("{}/client/hubs/conversations?host={}&", ws_up_stream, matched_str);
json = re.replace(&json, replacement.as_str()).to_string();
}
}
return Ok(builder
.body(StreamBody::new(Body::from(json)))
.map_err(ResponseError::InternalServerError)?
.into_response());

}

}
// Modify files endpoint response
if with_context!(enable_file_proxy) && resp.inner.url().path().contains("/backend-api/files") {
if with_context!(enable_file_proxy) && resp.inner.url().path().contains("/backend-api/files") {
let url = resp.inner.url().clone();
// Files endpoint handling
let mut json = resp
Expand Down
61 changes: 48 additions & 13 deletions crates/openai/src/serve/proxy/toapi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::now_duration;
use crate::serve::error::ProxyError;
use crate::serve::ProxyResult;
use crate::token;

use crate::{
arkose::ArkoseToken,
chatgpt::model::req::{Content, ConversationMode, Messages, PostConvoRequest},
Expand Down Expand Up @@ -120,12 +121,23 @@ pub(super) async fn send_request(req: RequestExt) -> Result<ResponseExt, Respons

// Create request
let parent_message_id = uuid();
let kind = if gpt_model.is_gizmo() {
ConversationMode{
kind:"gizmo_interaction",
gizmo_id:Some(body.model.as_str())
}
} else {
ConversationMode{
kind:"primary_assistant",
gizmo_id:None
}
};


let req_body = PostConvoRequest::builder()
.action(Action::Next)
.arkose_token(arkose_token.as_deref())
.conversation_mode(ConversationMode {
kind: "primary_assistant",
})
.conversation_mode(kind)
.force_paragen(false)
.force_rate_limit(false)
.history_and_training_disabled(true)
Expand All @@ -152,6 +164,9 @@ pub(super) async fn send_request(req: RequestExt) -> Result<ResponseExt, Respons
.send()
.await
.map_err(ResponseError::InternalServerError)?;

// Check resp content-type
// will handle sse/wss

Ok(ResponseExt::builder()
.inner(resp)
Expand All @@ -163,6 +178,13 @@ pub(super) async fn send_request(req: RequestExt) -> Result<ResponseExt, Respons
)
.build())
}

/// Conversation-endpoint response body returned when the upstream answers
/// with `application/json` instead of an SSE stream, pointing the client at
/// a websocket to receive the conversation events.
#[derive(serde::Deserialize)]
struct WeResp {
    /// URL of the websocket stream carrying the conversation events.
    wss_url: String,
    /// Conversation id associated with the websocket stream.
    conversation_id: String,
}


/// Convert response to ChatGPT API
pub(super) async fn response_convert(
Expand All @@ -174,19 +196,32 @@ pub(super) async fn response_convert(
let config = resp_ext.context.ok_or(ResponseError::InternalServerError(
ProxyError::RequestContentIsEmpty,
))?;

// Get response body event source
let event_source = resp.bytes_stream().eventsource();

if config.stream {
// Create a stream response
let stream = stream::stream_handler(event_source, config.model)?;
// process applicaiton for wss
if resp.headers().get(header::CONTENT_TYPE).unwrap().to_str().unwrap().eq("application/json") {
// Get response body
let body = resp.json::<WeResp>().await?;
// Create a not stream response
let stream = stream::ws_stream_handler(
body.wss_url ,
body.conversation_id ,
config.model).await?;
Ok(Sse::new(stream).into_response())

} else {
// Create a not stream response
let no_stream = stream::not_stream_handler(event_source, config.model).await?;
Ok(no_stream.into_response())

// Get response body event source
let event_source = resp.bytes_stream().eventsource();
if config.stream {
// Create a stream response
let stream = stream::stream_handler(event_source, config.model)?;
Ok(Sse::new(stream).into_response())
} else {
// Create a not stream response
let no_stream = stream::not_stream_handler(event_source, config.model).await?;
Ok(no_stream.into_response())
}
}

}
Err(err) => Ok(handle_error_response(err)?.into_response()),
}
Expand Down
16 changes: 16 additions & 0 deletions crates/openai/src/serve/proxy/toapi/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,19 @@ pub struct Delta<'a> {
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<&'a str>,
}



/// Envelope of a single message received over the conversation websocket.
#[derive(Deserialize, Default, Clone)]
pub struct WSStreamData {
    /// Message payload; absent for control/keep-alive frames — TODO confirm.
    pub data: Option<WSStreamDataBody>,
    /// Message type discriminator (serialized as "type").
    #[serde(rename = "type")]
    pub msg_type: String,
}

/// Payload of a websocket conversation message.
#[derive(Deserialize, Default, Clone)]
pub struct WSStreamDataBody {
    /// Raw message body content.
    pub body: String,
    /// Conversation this message belongs to.
    pub conversation_id: String,
    /// Whether further body fragments follow this one.
    pub more_body: bool,
}
Loading
Loading