Skip to content

Commit

Permalink
feat: escape non ascii in json
Browse files Browse the repository at this point in the history
  • Loading branch information
honsunrise committed Apr 9, 2024
1 parent f70ed12 commit 616f637
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
8 changes: 6 additions & 2 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::pin::Pin;

use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use reqwest::header::{HeaderValue, CONTENT_TYPE};
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};

Expand All @@ -11,6 +12,7 @@ use crate::{
file::Files,
image::Images,
moderation::Moderations,
util::escape_non_ascii_json,
Assistants, Audio, Chat, Completions, Embeddings, FineTuning, Models, Threads,
};

Expand Down Expand Up @@ -196,7 +198,8 @@ impl<C: Config> Client<C> {
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
.body(escape_non_ascii_json(&request)?)
.build()?)
};

Expand All @@ -215,7 +218,8 @@ impl<C: Config> Client<C> {
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.header(CONTENT_TYPE, HeaderValue::from_static("application/json"))
.body(escape_non_ascii_json(&request)?)
.build()?)
};

Expand Down
35 changes: 35 additions & 0 deletions async-openai/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,38 @@ pub(crate) fn create_all_dir<P: AsRef<Path>>(dir: P) -> Result<(), OpenAIError>

Ok(())
}

/// Formatter for serializing JSON with non-ASCII characters escaped.
///
/// ASCII characters in a string fragment are written through unchanged; every
/// other character is written as JSON `\uXXXX` escapes of its UTF-16 code
/// units. Characters inside the Basic Multilingual Plane produce a single
/// escape, characters outside it produce a surrogate pair (two escapes).
pub(crate) struct EscapeNonAscii;

impl serde_json::ser::Formatter for EscapeNonAscii {
    /// Writes `fragment` to `writer`, replacing each non-ASCII character with
    /// its `\uXXXX` escape sequence(s).
    ///
    /// # Errors
    ///
    /// Propagates any I/O error reported by `writer`.
    fn write_string_fragment<W: ?Sized + std::io::Write>(
        &mut self,
        writer: &mut W,
        fragment: &str,
    ) -> std::io::Result<()> {
        for ch in fragment.chars() {
            if ch.is_ascii() {
                writer.write_all(ch.encode_utf8(&mut [0; 4]).as_bytes())?;
            } else {
                // `encode_utf16` yields one unit for BMP characters and two
                // (a surrogate pair) otherwise, so iterate over what was
                // actually produced instead of indexing a fixed second slot,
                // which would panic for every BMP character.
                let mut units = [0u16; 2];
                for &unit in ch.encode_utf16(&mut units).iter() {
                    write!(writer, "\\u{unit:04x}")?;
                }
            }
        }
        Ok(())
    }
}

/// Encodes `value` as JSON bytes, routing string output through the
/// [`EscapeNonAscii`] formatter.
///
/// # Errors
///
/// Returns `OpenAIError::InvalidArgument` holding the alternate-format
/// rendering of the serializer error when `value` cannot be serialized.
pub(crate) fn escape_non_ascii_json<T: serde::Serialize>(
    value: &T,
) -> Result<Vec<u8>, OpenAIError> {
    let mut buffer: Vec<u8> = Vec::with_capacity(128);
    let mut serializer = serde_json::Serializer::with_formatter(&mut buffer, EscapeNonAscii);

    if let Err(err) = value.serialize(&mut serializer) {
        return Err(OpenAIError::InvalidArgument(format!("{err:#}")));
    }

    Ok(buffer)
}

0 comments on commit 616f637

Please sign in to comment.