diff --git a/.gitignore b/.gitignore index e4e78ec..8ced0ed 100644 --- a/.gitignore +++ b/.gitignore @@ -12,7 +12,25 @@ target/ .DS_Store ._.DS_Store -config.json +# Docker Files mnt data docker-compose.dev.yml + +# TeleGPT Default Config Files +*.config.json +config.json + +# Visual Studio Code Configurations +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..11f64fc --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,23 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "TeleGPT Nightly", + "cargo": { + "args": ["build", "--bin=telegpt", "--package=telegpt"], + "filter": { + "name": "telegpt", + "kind": "bin" + } + }, + "args": [], + "cwd": "${workspaceFolder}", + "env": { + "RUST_BACKTRACE": "full", + "RUST_LOG": "DEBUG" + } + } + ] +} diff --git a/Cargo.lock b/Cargo.lock index 96e9298..6d0c801 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -43,9 +43,9 @@ dependencies = [ [[package]] name = "async-openai" -version = "0.8.0" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c791cd9568241317f49bb3e3a6b595c986a2022428caa16a316be2520d05acf1" +checksum = "25a497fb330310be352a9e30a040804cd3f16f5ae2e57eeedd50c0f530189c88" dependencies = [ "backoff", "base64", @@ -562,6 +562,15 @@ dependencies = [ "slab", ] +[[package]] +name = "getopts" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5" +dependencies = [ + "unicode-width", +] + [[package]] name = "getrandom" version = "0.2.8" @@ -1142,6 +1151,18 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "pulldown-cmark" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d9cc634bc78768157b5cbfe988ffcd1dcba95cd2b2f03a88316c08c6d00ed63" +dependencies = [ + "bitflags", + "getopts", + "memchr", + "unicase", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -1497,6 +1518,7 @@ dependencies = [ "paste", "pin-project-lite", "pretty_env_logger", + "pulldown-cmark", "rusqlite", "serde", "serde_json", @@ -1770,6 +1792,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-width" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" + [[package]] name = "url" version = "2.3.1" diff --git a/Cargo.toml b/Cargo.toml index 499c814..65b990e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ strip = true [dependencies] teloxide = { version = "0.12", features = ["macros"] } -async-openai = "0.8" +async-openai = "0.9" tokio = { version = "1", features = ["full"] } futures = "0.3" pin-project-lite = "0.2" @@ -34,4 +34,5 @@ env_logger = "0.10" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" paste = "1.0" -clap = { version = "4.0", features = ["derive"] } \ No newline at end of file +clap = { version = "4.0", features = ["derive"] } +pulldown-cmark = "0.9" \ No newline at end of file diff --git a/README.md b/README.md index c932e63..b10df44 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,12 @@ TeleGPT is a Telegram bot based on [**teloxide**](https://github.com/teloxide/te ## Features -🦀 **Lightning fast** with pure Rust codebase.
-📢 **All types of chat** (private and group) supports.
-🚀 **Live streaming tokens** to your message bubble.
-💸 **Token usage** statistic recording and queryable via commands.
-⚙️ **Fully customizable** with file-based configuration.
+🦀 **Lightning fast** with pure Rust codebase. +📢 **All types of chat** (private and group) supports. +🚀 **Live streaming tokens** to your message bubble. +⌨️ **Telegram-flavoured Markdown** rendering supports. +💸 **Token usage** statistic recording and queryable via commands. +⚙️ **Fully customizable** with file-based configuration. ✋ **Admin features** (Beta) and user access control supports. ## Getting TeleGPT diff --git a/src/config.rs b/src/config.rs index ea129e0..b552420 100644 --- a/src/config.rs +++ b/src/config.rs @@ -86,6 +86,13 @@ pub struct Config { #[serde(default = "default_conversation_limit", rename = "conversationLimit")] pub conversation_limit: u64, + /// A boolean value that indicates whether to parse and render the + /// markdown contents. When set to `false`, the raw contents returned + /// from OpenAI will be displayed. This is default to `false`. + /// JSON key: `rendersMarkdown` + #[serde(default = "default_renders_markdown", rename = "rendersMarkdown")] + pub renders_markdown: bool, + /// A path for storing the database, [`None`] for in-memory database. /// JSON key: `databasePath` #[serde(rename = "databasePath")] @@ -142,6 +149,7 @@ define_defaults! { openai_api_timeout: u64 = 10, stream_throttle_interval: u64 = 500, conversation_limit: u64 = 20, + renders_markdown: bool = false, } define_defaults!(I18nStrings { diff --git a/src/modules/chat/markdown.rs b/src/modules/chat/markdown.rs new file mode 100644 index 0000000..6763c7d --- /dev/null +++ b/src/modules/chat/markdown.rs @@ -0,0 +1,498 @@ +use std::marker::PhantomData; + +use pulldown_cmark::{ + CodeBlockKind, CowStr, Event as CmarkEvent, Options as CmarkOptions, Parser as CmarkParser, + Tag as CmarkTag, +}; +use teloxide::types::{MessageEntity, MessageEntityKind}; + +#[derive(Debug, Default)] +pub struct ParsedString { + pub content: String, + pub entities: Vec, +} + +impl ParsedString { + fn with_str(string: &str) -> Self { + Self { + content: string.to_owned(), + entities: vec![], + } + } +} + +enum Event<'a> { + Start(Tag<'a>), + End(Tag<'a>), + Text(CowStr<'a>), + Code(CowStr<'a>), + Break, +} + +#[derive(Clone, Debug)] +enum Tag<'a> { + Paragraph, + Heading(u32), + CodeBlock(Option>), + List(Option), + Item, + Italic, + Bold, + Strikethrough, + Link(CowStr<'a>), + Image(CowStr<'a>), +} + +impl<'a> TryFrom> for Tag<'a> { + type Error = ParserError<'a>; + + fn try_from(value: CmarkTag<'a>) -> Result { + let mapped = match value { + CmarkTag::Paragraph | CmarkTag::BlockQuote => Tag::Paragraph, + CmarkTag::Heading(level, _, _) => Tag::Heading(level as _), + CmarkTag::CodeBlock(code_block_kind) => match code_block_kind { + CodeBlockKind::Indented => Tag::CodeBlock(None), + CodeBlockKind::Fenced(lang) => Tag::CodeBlock(Some(lang)), + }, + CmarkTag::List(first) => Tag::List(first), + CmarkTag::Item => Tag::Item, + CmarkTag::Emphasis => Tag::Italic, + CmarkTag::Strong => Tag::Bold, + CmarkTag::Strikethrough => Tag::Strikethrough, + CmarkTag::Link(_, url, _) => Tag::Link(url), + CmarkTag::Image(_, url, _) => Tag::Image(url), + _ => return Err(ParserError::UnexpectedCmarkTag(value)), + }; + Ok(mapped) + } +} + +impl<'a> TryFrom> for Event<'a> { + type Error = ParserError<'a>; + + fn try_from(value: CmarkEvent<'a>) -> Result { + let mapped = match value { + CmarkEvent::Start(tag) => Event::Start(tag.try_into()?), + CmarkEvent::End(tag) => Event::End(tag.try_into()?), + CmarkEvent::Text(text) => Event::Text(text), + CmarkEvent::Html(text) | CmarkEvent::Code(text) => Event::Code(text), + CmarkEvent::SoftBreak | CmarkEvent::HardBreak => Event::Break, + CmarkEvent::Rule => Event::Text(CowStr::Borrowed("---")), + _ => { + return Err(ParserError::UnexpectedCmarkEvent(value)); + } + }; + Ok(mapped) + } +} + +#[derive(Clone, Debug)] +enum EntityKind { + TelegramEntityKind(MessageEntityKind), + List(Option), +} + +impl<'a> TryFrom<&Tag<'a>> for EntityKind { + type Error = ParserError<'a>; + + fn try_from(value: &Tag<'a>) -> Result { + let mapped = match value { + Tag::List(start) => EntityKind::List(*start), + Tag::CodeBlock(lang) => EntityKind::TelegramEntityKind(MessageEntityKind::Pre { + language: lang.as_ref().map(|lang| lang.to_string()), + }), + Tag::Italic => EntityKind::TelegramEntityKind(MessageEntityKind::Italic), + Tag::Bold => EntityKind::TelegramEntityKind(MessageEntityKind::Bold), + Tag::Strikethrough => EntityKind::TelegramEntityKind(MessageEntityKind::Strikethrough), + Tag::Link(url) | Tag::Image(url) => { + EntityKind::TelegramEntityKind(MessageEntityKind::TextLink { + url: url + .parse() + .map_err(|_| ParserError::InvalidURL(url.clone()))?, + }) + } + _ => { + return Err(ParserError::UnexpectedTag(value.clone())); + } + }; + Ok(mapped) + } +} + +#[derive(Clone, Debug)] +struct Entity { + kind: EntityKind, + start: usize, +} + +const PARAGRAPH_MARGIN: usize = 2; +const LIST_ITEM_MARGIN: usize = 1; + +#[derive(Debug)] +enum ParserError<'input> { + /// Cannot convert the Cmark tag to our tag. + UnexpectedCmarkTag(CmarkTag<'input>), + /// Cannot handle the Cmark event. + UnexpectedCmarkEvent(CmarkEvent<'input>), + /// Cannot parse the given URL string. + InvalidURL(CowStr<'input>), + /// Cannot handle the tag. + UnexpectedTag(Tag<'input>), + /// Meet unmatched entity. The first field is the current entity kind, + /// and the second field is string of the expected kind. + UnmatchedEntity(Option, &'static str), +} + +type ParserEventResult<'input> = Result<(), ParserError<'input>>; + +#[derive(Debug)] +struct ParseState<'p> { + entity_stack: Vec, + parsed_string: ParsedString, + utf16_offset: usize, + prev_block_margin: usize, + phantom: PhantomData<&'p str>, +} + +impl<'p> ParseState<'p> { + fn new() -> Self { + Self { + entity_stack: Vec::new(), + parsed_string: ParsedString::default(), + utf16_offset: 0, + prev_block_margin: 0, + phantom: PhantomData, + } + } + + fn close(self) -> ParsedString { + let Self { + mut parsed_string, + prev_block_margin, + .. + } = self; + + // Trim the redundant trailing margins. + parsed_string + .content + .truncate(parsed_string.content.len() - prev_block_margin); + + parsed_string + } + + #[allow(clippy::result_large_err)] + fn next_state<'input: 'p>(mut self, event: Event<'input>) -> Result> { + match event { + Event::Start(tag) => self.start(tag)?, + Event::End(tag) => self.end(tag)?, + Event::Text(text) => self.text(text), + Event::Code(text) => self.code(text), + Event::Break => self.r#break(), + }; + Ok(self) + } + + #[allow(clippy::result_large_err)] + fn start<'input: 'p>(&mut self, tag: Tag<'input>) -> ParserEventResult<'input> { + match tag { + Tag::Paragraph => {} + Tag::Heading(level) => { + self.push_str(&format!("{} ", "#".repeat(level as _))); + } + Tag::Item => { + let top_entity_kind = self.entity_stack.last().map(|e| &e.kind); + let item_marker = top_entity_kind + .ok_or_else(|| ParserError::UnmatchedEntity(top_entity_kind.cloned(), "List")) + .and_then(|kind| match kind { + EntityKind::List(Some(start)) => Ok(format!("{}. ", start)), + EntityKind::List(None) => Ok("• ".to_owned()), + _ => Err(ParserError::UnmatchedEntity(Some(kind.clone()), "List")), + })?; + self.push_str(&item_marker); + } + ref tag_ref => { + let entity_kind = tag_ref + .try_into() + .map_err(|_| ParserError::UnexpectedTag(tag))?; + self.entity_stack.push(Entity { + kind: entity_kind, + start: self.utf16_offset, + }); + } + } + Ok(()) + } + + #[allow(clippy::result_large_err)] + fn end<'input: 'p>(&mut self, tag: Tag<'input>) -> ParserEventResult<'input> { + match tag { + Tag::Paragraph | Tag::Heading(_) => { + self.push_block(PARAGRAPH_MARGIN); + } + Tag::CodeBlock(_) => { + let Entity { kind, start } = self + .entity_stack + .pop() + .ok_or(ParserError::UnmatchedEntity(None, "Pre"))?; + let has_trailing_newline = if self.parsed_string.content.ends_with('\n') { + // Usually, there will be a newline in the end of the code block. + // We want to take it into consideration when performing collapsing. + self.prev_block_margin = 1; + true + } else { + false + }; + self.parsed_string.entities.push(MessageEntity { + kind: if let EntityKind::TelegramEntityKind( + kind @ MessageEntityKind::Pre { .. }, + ) = kind + { + kind + } else { + return Err(ParserError::UnmatchedEntity(Some(kind), "Pre")); + }, + offset: start, + length: self.utf16_offset - start - (if has_trailing_newline { 1 } else { 0 }), + }); + + self.push_block(PARAGRAPH_MARGIN); + } + Tag::List(_) => { + let Entity { kind, .. } = self + .entity_stack + .pop() + .ok_or(ParserError::UnmatchedEntity(None, "List"))?; + if let EntityKind::List(_) = kind { + self.push_block(PARAGRAPH_MARGIN); + } else { + return Err(ParserError::UnmatchedEntity(Some(kind), "List")); + } + } + Tag::Item => { + if let Some(Entity { + kind: EntityKind::List(maybe_start_number), + .. + }) = self.entity_stack.last_mut() + { + if let Some(start_number) = maybe_start_number { + *start_number += 1; + } + } else { + return Err(ParserError::UnmatchedEntity( + self.entity_stack.last().map(|e| e.kind.clone()), + "List", + )); + } + self.push_block(LIST_ITEM_MARGIN) + } + Tag::Italic | Tag::Bold | Tag::Strikethrough => { + let Entity { kind, start } = self + .entity_stack + .pop() + .ok_or(ParserError::UnmatchedEntity(None, "InlineFormat"))?; + self.parsed_string.entities.push(MessageEntity { + kind: if let EntityKind::TelegramEntityKind(kind) = kind { + // FIXME: continue to validate the `MessageEntityKind`. + kind + } else { + return Err(ParserError::UnmatchedEntity(Some(kind), "InlineFormat")); + }, + offset: start, + length: self.utf16_offset - start, + }); + } + Tag::Link(_) | Tag::Image(_) => { + let Entity { kind, start } = self + .entity_stack + .pop() + .ok_or(ParserError::UnmatchedEntity(None, "LinkOrImage"))?; + + self.parsed_string.entities.push(MessageEntity { + kind: if let EntityKind::TelegramEntityKind( + kind @ MessageEntityKind::TextLink { .. }, + ) = kind + { + kind + } else { + return Err(ParserError::UnmatchedEntity(Some(kind), "LinkOrImage")); + }, + offset: start, + length: self.utf16_offset - start, + }); + } + } + Ok(()) + } + + fn text(&mut self, text: CowStr) { + self.push_str(&text); + } + + fn code(&mut self, text: CowStr) { + let offset = self.utf16_offset; + self.push_str(&text); + self.parsed_string.entities.push(MessageEntity { + kind: MessageEntityKind::Code, + offset, + length: self.utf16_offset - offset, + }); + } + + fn r#break(&mut self) { + self.push_str("\n"); + } + + fn push_str(&mut self, string: &str) { + let utf16_len_inc = string.encode_utf16().count(); + self.parsed_string.content.push_str(string); + self.utf16_offset += utf16_len_inc; + self.prev_block_margin = 0; + } + + fn push_block(&mut self, margin: usize) { + if self.prev_block_margin >= margin { + return; + } + + let this_margin = margin - self.prev_block_margin; + self.push_str(&"\n".repeat(this_margin)); + self.prev_block_margin = margin; + } +} + +#[allow(unused)] +pub fn parse(content: &str) -> ParsedString { + let mut options = CmarkOptions::empty(); + options.insert(CmarkOptions::ENABLE_STRIKETHROUGH); + let mut parser = CmarkParser::new_ext(content, options); + + let result = parser.try_fold(ParseState::new(), |acc, event| { + let mapped_event = Event::try_from(event)?; + acc.next_state(mapped_event) + }); + + match result { + Ok(state) => state.close(), + Err(err) => { + error!("Error while parsing Markdown: {:?}", err); + ParsedString::with_str(content) + } + } +} + +#[cfg(test)] +mod tests { + use teloxide::types::{MessageEntity, MessageEntityKind}; + + use super::*; + + #[test] + fn test_parse_simple() { + let raw = r#"# Heading +- list item 1 +- list item 2 + +Next Paragraph"#; + let expected_content = r#"# Heading + +• list item 1 +• list item 2 + +Next Paragraph"#; + let parsed = parse(raw); + + assert_eq!(parsed.content, expected_content); + } + + #[test] + fn test_parse_paragraph_list() { + let raw = r#"- list item 1 + +- list item 2 + +- list item 3"#; + let expected_content = r#"• list item 1 + +• list item 2 + +• list item 3"#; + let parsed = parse(raw); + + assert_eq!(parsed.content, expected_content); + } + + #[test] + fn test_code() { + let raw = r#"This is a code snippet: +```c +printf("hello\n"); +``` + +End"#; + let expected_content = r#"This is a code snippet: + +printf("hello\n"); + +End"#; + let parsed = parse(raw); + + assert_eq!(parsed.content, expected_content); + assert!(matches!( + parsed.entities[0], + MessageEntity { + kind: MessageEntityKind::Pre { + language: Some(ref lang) + }, + offset: 25, + length: 19 + } if lang == "c" + )); + } + + #[test] + fn test_inline_formats() { + let raw = r#"this is **bold *bold italic* text**"#; + let expected_content = r#"this is bold bold italic text"#; + let parsed = parse(raw); + + println!("{:#?}", parsed); + assert_eq!(parsed.content, expected_content); + assert!(matches!( + parsed.entities[0], + MessageEntity { + kind: MessageEntityKind::Italic, + offset: 13, + length: 11 + } + )); + assert!(matches!( + parsed.entities[1], + MessageEntity { + kind: MessageEntityKind::Bold, + offset: 8, + length: 21 + } + )); + } + + #[test] + fn test_malformed_url() { + let raw = r#"This is a [link](invalid)"#; + let parsed = parse(raw); + assert_eq!(parsed.content, raw); + } + + #[test] + fn test_codeblock_only() { + let raw = r#"``` +// line 1 +// line 2 +```"#; + let expected_content = r#"// line 1 +// line 2"#; + let parsed = parse(raw); + + assert_eq!(parsed.content, expected_content); + assert_eq!(parsed.entities[0].length, 19); + } +} diff --git a/src/modules/chat/mod.rs b/src/modules/chat/mod.rs index cd89868..3f618e3 100644 --- a/src/modules/chat/mod.rs +++ b/src/modules/chat/mod.rs @@ -1,6 +1,7 @@ #![allow(clippy::too_many_arguments)] mod braille; +mod markdown; mod openai_client; mod session; mod session_mgr; @@ -146,6 +147,46 @@ async fn handle_retry_action( true } +async fn handle_show_raw_action( + bot: Bot, + query: CallbackQuery, + session_mgr: SessionManager, +) -> bool { + let history_msg_id: Option = query + .data + .as_ref() + .and_then(|data| data.strip_prefix("/show_raw:")) + .and_then(|id_str| id_str.parse().ok()); + if history_msg_id.is_none() { + return false; + } + let history_msg_id = history_msg_id.unwrap(); + + let message = query.message; + if message.is_none() { + return false; + } + let message = message.unwrap(); + let chat_id = message.chat.id; + + let history_message = session_mgr.with_mut_session(chat_id.to_string(), |session| { + session.get_history_message(history_msg_id) + }); + + match history_message { + Some(history_message) => { + let _ = bot + .edit_message_text(chat_id, message.id, history_message.content) + .await; + } + None => { + let _ = bot.send_message(chat_id, "The message is stale.").await; + } + } + + true +} + async fn actually_handle_chat_message( bot: Bot, reply_to_msg: Option, @@ -185,15 +226,65 @@ async fn actually_handle_chat_message( // Record stats and add the reply to history. let reply_result = match result { Ok(res) => { - session_mgr.add_message_to_session(chat_id.clone(), user_msg); - session_mgr.add_message_to_session( - chat_id.clone(), - ChatCompletionRequestMessageArgs::default() - .role(Role::Assistant) - .content(res.content) - .build() - .unwrap(), - ); + let reply_history_message = session_mgr.with_mut_session(chat_id.clone(), |session| { + session.prepare_history_message( + ChatCompletionRequestMessageArgs::default() + .role(Role::Assistant) + .content(&res.content) + .build() + .unwrap(), + ) + }); + + let need_fallback = if config.renders_markdown { + let parsed_content = markdown::parse(&res.content); + #[cfg(debug_assertions)] + { + debug!( + "rendered Markdown contents: {}\ninto: {:#?}", + res.content, parsed_content + ); + } + let mut edit_message_text = bot.edit_message_text( + chat_id.to_owned(), + sent_progress_msg.id, + parsed_content.content, + ); + if !parsed_content.entities.is_empty() { + let show_raw_button = InlineKeyboardButton::callback( + "Show Raw Contents", + format!("/show_raw:{}", reply_history_message.id), + ); + edit_message_text.entities = Some(parsed_content.entities); + edit_message_text.reply_markup = + Some(InlineKeyboardMarkup::default().append_row([show_raw_button])); + } + if let Err(first_trial_err) = edit_message_text.await { + // TODO: test if the error is related to Markdown before + // fallback to raw contents. + error!( + "failed to send message (will fallback to raw contents): {}", + first_trial_err + ); + true + } else { + false + } + } else { + true + }; + + if need_fallback { + bot.edit_message_text(chat_id.to_owned(), sent_progress_msg.id, &res.content) + .await?; + } + + session_mgr.with_mut_session(chat_id.clone(), |session| { + let user_history_msg = session.prepare_history_message(user_msg); + session.add_history_message(user_history_msg); + session.add_history_message(reply_history_message); + }); + // TODO: maybe we need to handle the case that `reply_to_msg` is `None`. if let Some(from_username) = reply_to_msg .as_ref() @@ -289,8 +380,7 @@ async fn stream_model_result( // in stream mode. Therefore we need to estimate it locally. last_response.token_usage = openai_client::estimate_tokens(&last_response.content) + estimated_prompt_tokens; - bot.edit_message_text(chat_id.to_owned(), editing_msg.id, &last_response.content) - .await?; + return Ok(last_response); } @@ -337,8 +427,8 @@ impl Module for Chat { ) .branch( Update::filter_callback_query() - .filter_async(handle_retry_action) - .endpoint(noop_handler), + .branch(dptree::filter_async(handle_retry_action).endpoint(noop_handler)) + .branch(dptree::filter_async(handle_show_raw_action).endpoint(noop_handler)), ) } diff --git a/src/modules/chat/session.rs b/src/modules/chat/session.rs index b92bdc7..dba5638 100644 --- a/src/modules/chat/session.rs +++ b/src/modules/chat/session.rs @@ -1,13 +1,64 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use async_openai::types::{ChatCompletionRequestMessage as Message, Role}; use crate::config::SharedConfig; +#[derive(Debug, Clone)] +pub struct HistoryMessage { + pub id: i64, + pub message: Message, +} + +#[derive(Debug, Default)] +struct HistoryMessagePool { + current_id: i64, + messages: HashMap, + deque: VecDeque, +} + +impl HistoryMessagePool { + fn prepare_message(&mut self, message: Message) -> HistoryMessage { + let (id, _) = self.current_id.overflowing_add(1); + self.current_id = id; + + HistoryMessage { id, message } + } + + fn push_message(&mut self, message: HistoryMessage) { + let id = message.id; + self.messages.insert(id, message); + self.deque.push_back(id); + } + + fn pop_message(&mut self) { + if let Some(evicted_id) = self.deque.pop_front() { + self.messages.remove(&evicted_id); + } + } + + fn clear(&mut self) { + self.deque.clear(); + self.messages.clear(); + } + + fn len(&self) -> usize { + self.deque.len() + } + + fn get_message(&self, id: &i64) -> Option<&HistoryMessage> { + self.messages.get(id) + } + + fn iter(&self) -> impl Iterator + '_ { + self.deque.iter().filter_map(|id| self.messages.get(id)) + } +} + #[derive(Debug)] pub struct Session { system_message: Option, - messages: VecDeque, + history_messages: HistoryMessagePool, pending_message: Option, config: SharedConfig, } @@ -16,7 +67,7 @@ impl Session { pub fn new(config: SharedConfig) -> Self { Self { system_message: None, - messages: VecDeque::with_capacity(6), + history_messages: Default::default(), pending_message: None, config, } @@ -24,26 +75,36 @@ impl Session { pub fn reset(&mut self) { self.system_message = None; - self.messages.clear(); + self.history_messages.clear(); self.pending_message = None; } - pub fn add_message(&mut self, msg: Message) { - if matches!(msg.role, Role::System) { + pub fn prepare_history_message(&mut self, message: Message) -> HistoryMessage { + self.history_messages.prepare_message(message) + } + + pub fn add_history_message(&mut self, message: HistoryMessage) { + if matches!(message.message.role, Role::System) { // Replace the previous system message, we only support // one system message at the same time. - self.system_message = Some(msg); + self.system_message = Some(message.message); return; } - if self.messages.len() >= (self.config.conversation_limit as usize) { - self.messages.pop_front(); + if self.history_messages.len() >= (self.config.conversation_limit as usize) { + self.history_messages.pop_message(); } - self.messages.push_back(msg); + self.history_messages.push_message(message); + } + + pub fn get_history_message(&self, id: i64) -> Option { + self.history_messages + .get_message(&id) + .map(|m| m.message.clone()) } pub fn get_history_messages(&self) -> Vec { - let msg_iter = self.messages.iter().cloned(); + let msg_iter = self.history_messages.iter().map(|m| m.message.clone()); if let Some(sys_msg) = &self.system_message { let prepend = [sys_msg.to_owned()]; prepend.into_iter().chain(msg_iter).collect() diff --git a/src/modules/chat/session_mgr.rs b/src/modules/chat/session_mgr.rs index 8d4a349..4f086c9 100644 --- a/src/modules/chat/session_mgr.rs +++ b/src/modules/chat/session_mgr.rs @@ -31,10 +31,6 @@ impl SessionManager { self.with_mut_session(key, |session| session.reset()); } - pub fn add_message_to_session(&self, key: String, msg: Message) { - self.with_mut_session(key, |session| session.add_message(msg)); - } - pub fn get_history_messages(&self, key: &str) -> Vec { self.with_mut_inner(|inner| { inner @@ -53,15 +49,7 @@ impl SessionManager { self.with_mut_session(key, |session| session.swap_pending_message(msg)) } - fn with_mut_inner(&self, f: F) -> R - where - F: FnOnce(&mut SessionManagerInner) -> R, - { - let mut inner_mut = self.inner.lock().unwrap(); - f(&mut inner_mut) - } - - fn with_mut_session(&self, key: String, f: F) -> R + pub fn with_mut_session(&self, key: String, f: F) -> R where F: FnOnce(&mut Session) -> R, { @@ -73,6 +61,14 @@ impl SessionManager { f(session_mut) }) } + + fn with_mut_inner(&self, f: F) -> R + where + F: FnOnce(&mut SessionManagerInner) -> R, + { + let mut inner_mut = self.inner.lock().unwrap(); + f(&mut inner_mut) + } } impl Clone for SessionManager {