diff --git a/.gitignore b/.gitignore
index e4e78ec..8ced0ed 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,7 +12,25 @@ target/
.DS_Store
._.DS_Store
-config.json
+# Docker Files
mnt
data
docker-compose.dev.yml
+
+# TeleGPT Default Config Files
+*.config.json
+config.json
+
+# Visual Studio Code Configurations
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+!.vscode/*.code-snippets
+
+# Local History for Visual Studio Code
+.history/
+
+# Built Visual Studio Code Extensions
+*.vsix
diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..11f64fc
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,23 @@
+{
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "type": "lldb",
+ "request": "launch",
+ "name": "TeleGPT Nightly",
+ "cargo": {
+ "args": ["build", "--bin=telegpt", "--package=telegpt"],
+ "filter": {
+ "name": "telegpt",
+ "kind": "bin"
+ }
+ },
+ "args": [],
+ "cwd": "${workspaceFolder}",
+ "env": {
+ "RUST_BACKTRACE": "full",
+ "RUST_LOG": "DEBUG"
+ }
+ }
+ ]
+}
diff --git a/Cargo.lock b/Cargo.lock
index 96e9298..6d0c801 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -43,9 +43,9 @@ dependencies = [
[[package]]
name = "async-openai"
-version = "0.8.0"
+version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c791cd9568241317f49bb3e3a6b595c986a2022428caa16a316be2520d05acf1"
+checksum = "25a497fb330310be352a9e30a040804cd3f16f5ae2e57eeedd50c0f530189c88"
dependencies = [
"backoff",
"base64",
@@ -562,6 +562,15 @@ dependencies = [
"slab",
]
+[[package]]
+name = "getopts"
+version = "0.2.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5"
+dependencies = [
+ "unicode-width",
+]
+
[[package]]
name = "getrandom"
version = "0.2.8"
@@ -1142,6 +1151,18 @@ dependencies = [
"unicode-ident",
]
+[[package]]
+name = "pulldown-cmark"
+version = "0.9.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2d9cc634bc78768157b5cbfe988ffcd1dcba95cd2b2f03a88316c08c6d00ed63"
+dependencies = [
+ "bitflags",
+ "getopts",
+ "memchr",
+ "unicase",
+]
+
[[package]]
name = "quick-error"
version = "1.2.3"
@@ -1497,6 +1518,7 @@ dependencies = [
"paste",
"pin-project-lite",
"pretty_env_logger",
+ "pulldown-cmark",
"rusqlite",
"serde",
"serde_json",
@@ -1770,6 +1792,12 @@ dependencies = [
"tinyvec",
]
+[[package]]
+name = "unicode-width"
+version = "0.1.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b"
+
[[package]]
name = "url"
version = "2.3.1"
diff --git a/Cargo.toml b/Cargo.toml
index 499c814..65b990e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,7 +21,7 @@ strip = true
[dependencies]
teloxide = { version = "0.12", features = ["macros"] }
-async-openai = "0.8"
+async-openai = "0.9"
tokio = { version = "1", features = ["full"] }
futures = "0.3"
pin-project-lite = "0.2"
@@ -34,4 +34,5 @@ env_logger = "0.10"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
paste = "1.0"
-clap = { version = "4.0", features = ["derive"] }
\ No newline at end of file
+clap = { version = "4.0", features = ["derive"] }
+pulldown-cmark = "0.9"
\ No newline at end of file
diff --git a/README.md b/README.md
index c932e63..b10df44 100644
--- a/README.md
+++ b/README.md
@@ -10,11 +10,12 @@ TeleGPT is a Telegram bot based on [**teloxide**](https://github.com/teloxide/te
## Features
-🦀 **Lightning fast** with pure Rust codebase.
-📢 **All types of chat** (private and group) supports.
-🚀 **Live streaming tokens** to your message bubble.
-💸 **Token usage** statistic recording and queryable via commands.
-⚙️ **Fully customizable** with file-based configuration.
+🦀 **Lightning fast** with pure Rust codebase.
+📢 **All types of chat** (private and group) supports.
+🚀 **Live streaming tokens** to your message bubble.
+⌨️ **Telegram-flavoured Markdown** rendering supports.
+💸 **Token usage** statistic recording and queryable via commands.
+⚙️ **Fully customizable** with file-based configuration.
✋ **Admin features** (Beta) and user access control supports.
## Getting TeleGPT
diff --git a/src/config.rs b/src/config.rs
index ea129e0..b552420 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -86,6 +86,13 @@ pub struct Config {
#[serde(default = "default_conversation_limit", rename = "conversationLimit")]
pub conversation_limit: u64,
+ /// A boolean value that indicates whether to parse and render the
+ /// markdown contents. When set to `false`, the raw contents returned
+ /// from OpenAI will be displayed. This is default to `false`.
+ /// JSON key: `rendersMarkdown`
+ #[serde(default = "default_renders_markdown", rename = "rendersMarkdown")]
+ pub renders_markdown: bool,
+
/// A path for storing the database, [`None`] for in-memory database.
/// JSON key: `databasePath`
#[serde(rename = "databasePath")]
@@ -142,6 +149,7 @@ define_defaults! {
openai_api_timeout: u64 = 10,
stream_throttle_interval: u64 = 500,
conversation_limit: u64 = 20,
+ renders_markdown: bool = false,
}
define_defaults!(I18nStrings {
diff --git a/src/modules/chat/markdown.rs b/src/modules/chat/markdown.rs
new file mode 100644
index 0000000..6763c7d
--- /dev/null
+++ b/src/modules/chat/markdown.rs
@@ -0,0 +1,498 @@
+use std::marker::PhantomData;
+
+use pulldown_cmark::{
+ CodeBlockKind, CowStr, Event as CmarkEvent, Options as CmarkOptions, Parser as CmarkParser,
+ Tag as CmarkTag,
+};
+use teloxide::types::{MessageEntity, MessageEntityKind};
+
+#[derive(Debug, Default)]
+pub struct ParsedString {
+ pub content: String,
+ pub entities: Vec,
+}
+
+impl ParsedString {
+ fn with_str(string: &str) -> Self {
+ Self {
+ content: string.to_owned(),
+ entities: vec![],
+ }
+ }
+}
+
+enum Event<'a> {
+ Start(Tag<'a>),
+ End(Tag<'a>),
+ Text(CowStr<'a>),
+ Code(CowStr<'a>),
+ Break,
+}
+
+#[derive(Clone, Debug)]
+enum Tag<'a> {
+ Paragraph,
+ Heading(u32),
+ CodeBlock(Option>),
+ List(Option),
+ Item,
+ Italic,
+ Bold,
+ Strikethrough,
+ Link(CowStr<'a>),
+ Image(CowStr<'a>),
+}
+
+impl<'a> TryFrom> for Tag<'a> {
+ type Error = ParserError<'a>;
+
+ fn try_from(value: CmarkTag<'a>) -> Result {
+ let mapped = match value {
+ CmarkTag::Paragraph | CmarkTag::BlockQuote => Tag::Paragraph,
+ CmarkTag::Heading(level, _, _) => Tag::Heading(level as _),
+ CmarkTag::CodeBlock(code_block_kind) => match code_block_kind {
+ CodeBlockKind::Indented => Tag::CodeBlock(None),
+ CodeBlockKind::Fenced(lang) => Tag::CodeBlock(Some(lang)),
+ },
+ CmarkTag::List(first) => Tag::List(first),
+ CmarkTag::Item => Tag::Item,
+ CmarkTag::Emphasis => Tag::Italic,
+ CmarkTag::Strong => Tag::Bold,
+ CmarkTag::Strikethrough => Tag::Strikethrough,
+ CmarkTag::Link(_, url, _) => Tag::Link(url),
+ CmarkTag::Image(_, url, _) => Tag::Image(url),
+ _ => return Err(ParserError::UnexpectedCmarkTag(value)),
+ };
+ Ok(mapped)
+ }
+}
+
+impl<'a> TryFrom> for Event<'a> {
+ type Error = ParserError<'a>;
+
+ fn try_from(value: CmarkEvent<'a>) -> Result {
+ let mapped = match value {
+ CmarkEvent::Start(tag) => Event::Start(tag.try_into()?),
+ CmarkEvent::End(tag) => Event::End(tag.try_into()?),
+ CmarkEvent::Text(text) => Event::Text(text),
+ CmarkEvent::Html(text) | CmarkEvent::Code(text) => Event::Code(text),
+ CmarkEvent::SoftBreak | CmarkEvent::HardBreak => Event::Break,
+ CmarkEvent::Rule => Event::Text(CowStr::Borrowed("---")),
+ _ => {
+ return Err(ParserError::UnexpectedCmarkEvent(value));
+ }
+ };
+ Ok(mapped)
+ }
+}
+
+#[derive(Clone, Debug)]
+enum EntityKind {
+ TelegramEntityKind(MessageEntityKind),
+ List(Option),
+}
+
+impl<'a> TryFrom<&Tag<'a>> for EntityKind {
+ type Error = ParserError<'a>;
+
+ fn try_from(value: &Tag<'a>) -> Result {
+ let mapped = match value {
+ Tag::List(start) => EntityKind::List(*start),
+ Tag::CodeBlock(lang) => EntityKind::TelegramEntityKind(MessageEntityKind::Pre {
+ language: lang.as_ref().map(|lang| lang.to_string()),
+ }),
+ Tag::Italic => EntityKind::TelegramEntityKind(MessageEntityKind::Italic),
+ Tag::Bold => EntityKind::TelegramEntityKind(MessageEntityKind::Bold),
+ Tag::Strikethrough => EntityKind::TelegramEntityKind(MessageEntityKind::Strikethrough),
+ Tag::Link(url) | Tag::Image(url) => {
+ EntityKind::TelegramEntityKind(MessageEntityKind::TextLink {
+ url: url
+ .parse()
+ .map_err(|_| ParserError::InvalidURL(url.clone()))?,
+ })
+ }
+ _ => {
+ return Err(ParserError::UnexpectedTag(value.clone()));
+ }
+ };
+ Ok(mapped)
+ }
+}
+
+#[derive(Clone, Debug)]
+struct Entity {
+ kind: EntityKind,
+ start: usize,
+}
+
+const PARAGRAPH_MARGIN: usize = 2;
+const LIST_ITEM_MARGIN: usize = 1;
+
+#[derive(Debug)]
+enum ParserError<'input> {
+ /// Cannot convert the Cmark tag to our tag.
+ UnexpectedCmarkTag(CmarkTag<'input>),
+ /// Cannot handle the Cmark event.
+ UnexpectedCmarkEvent(CmarkEvent<'input>),
+ /// Cannot parse the given URL string.
+ InvalidURL(CowStr<'input>),
+ /// Cannot handle the tag.
+ UnexpectedTag(Tag<'input>),
+ /// Meet unmatched entity. The first field is the current entity kind,
+ /// and the second field is string of the expected kind.
+ UnmatchedEntity(Option, &'static str),
+}
+
+type ParserEventResult<'input> = Result<(), ParserError<'input>>;
+
+#[derive(Debug)]
+struct ParseState<'p> {
+ entity_stack: Vec,
+ parsed_string: ParsedString,
+ utf16_offset: usize,
+ prev_block_margin: usize,
+ phantom: PhantomData<&'p str>,
+}
+
+impl<'p> ParseState<'p> {
+ fn new() -> Self {
+ Self {
+ entity_stack: Vec::new(),
+ parsed_string: ParsedString::default(),
+ utf16_offset: 0,
+ prev_block_margin: 0,
+ phantom: PhantomData,
+ }
+ }
+
+ fn close(self) -> ParsedString {
+ let Self {
+ mut parsed_string,
+ prev_block_margin,
+ ..
+ } = self;
+
+ // Trim the redundant trailing margins.
+ parsed_string
+ .content
+ .truncate(parsed_string.content.len() - prev_block_margin);
+
+ parsed_string
+ }
+
+ #[allow(clippy::result_large_err)]
+ fn next_state<'input: 'p>(mut self, event: Event<'input>) -> Result> {
+ match event {
+ Event::Start(tag) => self.start(tag)?,
+ Event::End(tag) => self.end(tag)?,
+ Event::Text(text) => self.text(text),
+ Event::Code(text) => self.code(text),
+ Event::Break => self.r#break(),
+ };
+ Ok(self)
+ }
+
+ #[allow(clippy::result_large_err)]
+ fn start<'input: 'p>(&mut self, tag: Tag<'input>) -> ParserEventResult<'input> {
+ match tag {
+ Tag::Paragraph => {}
+ Tag::Heading(level) => {
+ self.push_str(&format!("{} ", "#".repeat(level as _)));
+ }
+ Tag::Item => {
+ let top_entity_kind = self.entity_stack.last().map(|e| &e.kind);
+ let item_marker = top_entity_kind
+ .ok_or_else(|| ParserError::UnmatchedEntity(top_entity_kind.cloned(), "List"))
+ .and_then(|kind| match kind {
+ EntityKind::List(Some(start)) => Ok(format!("{}. ", start)),
+ EntityKind::List(None) => Ok("• ".to_owned()),
+ _ => Err(ParserError::UnmatchedEntity(Some(kind.clone()), "List")),
+ })?;
+ self.push_str(&item_marker);
+ }
+ ref tag_ref => {
+ let entity_kind = tag_ref
+ .try_into()
+ .map_err(|_| ParserError::UnexpectedTag(tag))?;
+ self.entity_stack.push(Entity {
+ kind: entity_kind,
+ start: self.utf16_offset,
+ });
+ }
+ }
+ Ok(())
+ }
+
+ #[allow(clippy::result_large_err)]
+ fn end<'input: 'p>(&mut self, tag: Tag<'input>) -> ParserEventResult<'input> {
+ match tag {
+ Tag::Paragraph | Tag::Heading(_) => {
+ self.push_block(PARAGRAPH_MARGIN);
+ }
+ Tag::CodeBlock(_) => {
+ let Entity { kind, start } = self
+ .entity_stack
+ .pop()
+ .ok_or(ParserError::UnmatchedEntity(None, "Pre"))?;
+ let has_trailing_newline = if self.parsed_string.content.ends_with('\n') {
+ // Usually, there will be a newline in the end of the code block.
+ // We want to take it into consideration when performing collapsing.
+ self.prev_block_margin = 1;
+ true
+ } else {
+ false
+ };
+ self.parsed_string.entities.push(MessageEntity {
+ kind: if let EntityKind::TelegramEntityKind(
+ kind @ MessageEntityKind::Pre { .. },
+ ) = kind
+ {
+ kind
+ } else {
+ return Err(ParserError::UnmatchedEntity(Some(kind), "Pre"));
+ },
+ offset: start,
+ length: self.utf16_offset - start - (if has_trailing_newline { 1 } else { 0 }),
+ });
+
+ self.push_block(PARAGRAPH_MARGIN);
+ }
+ Tag::List(_) => {
+ let Entity { kind, .. } = self
+ .entity_stack
+ .pop()
+ .ok_or(ParserError::UnmatchedEntity(None, "List"))?;
+ if let EntityKind::List(_) = kind {
+ self.push_block(PARAGRAPH_MARGIN);
+ } else {
+ return Err(ParserError::UnmatchedEntity(Some(kind), "List"));
+ }
+ }
+ Tag::Item => {
+ if let Some(Entity {
+ kind: EntityKind::List(maybe_start_number),
+ ..
+ }) = self.entity_stack.last_mut()
+ {
+ if let Some(start_number) = maybe_start_number {
+ *start_number += 1;
+ }
+ } else {
+ return Err(ParserError::UnmatchedEntity(
+ self.entity_stack.last().map(|e| e.kind.clone()),
+ "List",
+ ));
+ }
+ self.push_block(LIST_ITEM_MARGIN)
+ }
+ Tag::Italic | Tag::Bold | Tag::Strikethrough => {
+ let Entity { kind, start } = self
+ .entity_stack
+ .pop()
+ .ok_or(ParserError::UnmatchedEntity(None, "InlineFormat"))?;
+ self.parsed_string.entities.push(MessageEntity {
+ kind: if let EntityKind::TelegramEntityKind(kind) = kind {
+ // FIXME: continue to validate the `MessageEntityKind`.
+ kind
+ } else {
+ return Err(ParserError::UnmatchedEntity(Some(kind), "InlineFormat"));
+ },
+ offset: start,
+ length: self.utf16_offset - start,
+ });
+ }
+ Tag::Link(_) | Tag::Image(_) => {
+ let Entity { kind, start } = self
+ .entity_stack
+ .pop()
+ .ok_or(ParserError::UnmatchedEntity(None, "LinkOrImage"))?;
+
+ self.parsed_string.entities.push(MessageEntity {
+ kind: if let EntityKind::TelegramEntityKind(
+ kind @ MessageEntityKind::TextLink { .. },
+ ) = kind
+ {
+ kind
+ } else {
+ return Err(ParserError::UnmatchedEntity(Some(kind), "LinkOrImage"));
+ },
+ offset: start,
+ length: self.utf16_offset - start,
+ });
+ }
+ }
+ Ok(())
+ }
+
+ fn text(&mut self, text: CowStr) {
+ self.push_str(&text);
+ }
+
+ fn code(&mut self, text: CowStr) {
+ let offset = self.utf16_offset;
+ self.push_str(&text);
+ self.parsed_string.entities.push(MessageEntity {
+ kind: MessageEntityKind::Code,
+ offset,
+ length: self.utf16_offset - offset,
+ });
+ }
+
+ fn r#break(&mut self) {
+ self.push_str("\n");
+ }
+
+ fn push_str(&mut self, string: &str) {
+ let utf16_len_inc = string.encode_utf16().count();
+ self.parsed_string.content.push_str(string);
+ self.utf16_offset += utf16_len_inc;
+ self.prev_block_margin = 0;
+ }
+
+ fn push_block(&mut self, margin: usize) {
+ if self.prev_block_margin >= margin {
+ return;
+ }
+
+ let this_margin = margin - self.prev_block_margin;
+ self.push_str(&"\n".repeat(this_margin));
+ self.prev_block_margin = margin;
+ }
+}
+
+#[allow(unused)]
+pub fn parse(content: &str) -> ParsedString {
+ let mut options = CmarkOptions::empty();
+ options.insert(CmarkOptions::ENABLE_STRIKETHROUGH);
+ let mut parser = CmarkParser::new_ext(content, options);
+
+ let result = parser.try_fold(ParseState::new(), |acc, event| {
+ let mapped_event = Event::try_from(event)?;
+ acc.next_state(mapped_event)
+ });
+
+ match result {
+ Ok(state) => state.close(),
+ Err(err) => {
+ error!("Error while parsing Markdown: {:?}", err);
+ ParsedString::with_str(content)
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use teloxide::types::{MessageEntity, MessageEntityKind};
+
+ use super::*;
+
+ #[test]
+ fn test_parse_simple() {
+ let raw = r#"# Heading
+- list item 1
+- list item 2
+
+Next Paragraph"#;
+ let expected_content = r#"# Heading
+
+• list item 1
+• list item 2
+
+Next Paragraph"#;
+ let parsed = parse(raw);
+
+ assert_eq!(parsed.content, expected_content);
+ }
+
+ #[test]
+ fn test_parse_paragraph_list() {
+ let raw = r#"- list item 1
+
+- list item 2
+
+- list item 3"#;
+ let expected_content = r#"• list item 1
+
+• list item 2
+
+• list item 3"#;
+ let parsed = parse(raw);
+
+ assert_eq!(parsed.content, expected_content);
+ }
+
+ #[test]
+ fn test_code() {
+ let raw = r#"This is a code snippet:
+```c
+printf("hello\n");
+```
+
+End"#;
+ let expected_content = r#"This is a code snippet:
+
+printf("hello\n");
+
+End"#;
+ let parsed = parse(raw);
+
+ assert_eq!(parsed.content, expected_content);
+ assert!(matches!(
+ parsed.entities[0],
+ MessageEntity {
+ kind: MessageEntityKind::Pre {
+ language: Some(ref lang)
+ },
+ offset: 25,
+ length: 19
+ } if lang == "c"
+ ));
+ }
+
+ #[test]
+ fn test_inline_formats() {
+ let raw = r#"this is **bold *bold italic* text**"#;
+ let expected_content = r#"this is bold bold italic text"#;
+ let parsed = parse(raw);
+
+ println!("{:#?}", parsed);
+ assert_eq!(parsed.content, expected_content);
+ assert!(matches!(
+ parsed.entities[0],
+ MessageEntity {
+ kind: MessageEntityKind::Italic,
+ offset: 13,
+ length: 11
+ }
+ ));
+ assert!(matches!(
+ parsed.entities[1],
+ MessageEntity {
+ kind: MessageEntityKind::Bold,
+ offset: 8,
+ length: 21
+ }
+ ));
+ }
+
+ #[test]
+ fn test_malformed_url() {
+ let raw = r#"This is a [link](invalid)"#;
+ let parsed = parse(raw);
+ assert_eq!(parsed.content, raw);
+ }
+
+ #[test]
+ fn test_codeblock_only() {
+ let raw = r#"```
+// line 1
+// line 2
+```"#;
+ let expected_content = r#"// line 1
+// line 2"#;
+ let parsed = parse(raw);
+
+ assert_eq!(parsed.content, expected_content);
+ assert_eq!(parsed.entities[0].length, 19);
+ }
+}
diff --git a/src/modules/chat/mod.rs b/src/modules/chat/mod.rs
index cd89868..3f618e3 100644
--- a/src/modules/chat/mod.rs
+++ b/src/modules/chat/mod.rs
@@ -1,6 +1,7 @@
#![allow(clippy::too_many_arguments)]
mod braille;
+mod markdown;
mod openai_client;
mod session;
mod session_mgr;
@@ -146,6 +147,46 @@ async fn handle_retry_action(
true
}
+async fn handle_show_raw_action(
+ bot: Bot,
+ query: CallbackQuery,
+ session_mgr: SessionManager,
+) -> bool {
+ let history_msg_id: Option = query
+ .data
+ .as_ref()
+ .and_then(|data| data.strip_prefix("/show_raw:"))
+ .and_then(|id_str| id_str.parse().ok());
+ if history_msg_id.is_none() {
+ return false;
+ }
+ let history_msg_id = history_msg_id.unwrap();
+
+ let message = query.message;
+ if message.is_none() {
+ return false;
+ }
+ let message = message.unwrap();
+ let chat_id = message.chat.id;
+
+ let history_message = session_mgr.with_mut_session(chat_id.to_string(), |session| {
+ session.get_history_message(history_msg_id)
+ });
+
+ match history_message {
+ Some(history_message) => {
+ let _ = bot
+ .edit_message_text(chat_id, message.id, history_message.content)
+ .await;
+ }
+ None => {
+ let _ = bot.send_message(chat_id, "The message is stale.").await;
+ }
+ }
+
+ true
+}
+
async fn actually_handle_chat_message(
bot: Bot,
reply_to_msg: Option,
@@ -185,15 +226,65 @@ async fn actually_handle_chat_message(
// Record stats and add the reply to history.
let reply_result = match result {
Ok(res) => {
- session_mgr.add_message_to_session(chat_id.clone(), user_msg);
- session_mgr.add_message_to_session(
- chat_id.clone(),
- ChatCompletionRequestMessageArgs::default()
- .role(Role::Assistant)
- .content(res.content)
- .build()
- .unwrap(),
- );
+ let reply_history_message = session_mgr.with_mut_session(chat_id.clone(), |session| {
+ session.prepare_history_message(
+ ChatCompletionRequestMessageArgs::default()
+ .role(Role::Assistant)
+ .content(&res.content)
+ .build()
+ .unwrap(),
+ )
+ });
+
+ let need_fallback = if config.renders_markdown {
+ let parsed_content = markdown::parse(&res.content);
+ #[cfg(debug_assertions)]
+ {
+ debug!(
+ "rendered Markdown contents: {}\ninto: {:#?}",
+ res.content, parsed_content
+ );
+ }
+ let mut edit_message_text = bot.edit_message_text(
+ chat_id.to_owned(),
+ sent_progress_msg.id,
+ parsed_content.content,
+ );
+ if !parsed_content.entities.is_empty() {
+ let show_raw_button = InlineKeyboardButton::callback(
+ "Show Raw Contents",
+ format!("/show_raw:{}", reply_history_message.id),
+ );
+ edit_message_text.entities = Some(parsed_content.entities);
+ edit_message_text.reply_markup =
+ Some(InlineKeyboardMarkup::default().append_row([show_raw_button]));
+ }
+ if let Err(first_trial_err) = edit_message_text.await {
+ // TODO: test if the error is related to Markdown before
+ // fallback to raw contents.
+ error!(
+ "failed to send message (will fallback to raw contents): {}",
+ first_trial_err
+ );
+ true
+ } else {
+ false
+ }
+ } else {
+ true
+ };
+
+ if need_fallback {
+ bot.edit_message_text(chat_id.to_owned(), sent_progress_msg.id, &res.content)
+ .await?;
+ }
+
+ session_mgr.with_mut_session(chat_id.clone(), |session| {
+ let user_history_msg = session.prepare_history_message(user_msg);
+ session.add_history_message(user_history_msg);
+ session.add_history_message(reply_history_message);
+ });
+
// TODO: maybe we need to handle the case that `reply_to_msg` is `None`.
if let Some(from_username) = reply_to_msg
.as_ref()
@@ -289,8 +380,7 @@ async fn stream_model_result(
// in stream mode. Therefore we need to estimate it locally.
last_response.token_usage =
openai_client::estimate_tokens(&last_response.content) + estimated_prompt_tokens;
- bot.edit_message_text(chat_id.to_owned(), editing_msg.id, &last_response.content)
- .await?;
+
return Ok(last_response);
}
@@ -337,8 +427,8 @@ impl Module for Chat {
)
.branch(
Update::filter_callback_query()
- .filter_async(handle_retry_action)
- .endpoint(noop_handler),
+ .branch(dptree::filter_async(handle_retry_action).endpoint(noop_handler))
+ .branch(dptree::filter_async(handle_show_raw_action).endpoint(noop_handler)),
)
}
diff --git a/src/modules/chat/session.rs b/src/modules/chat/session.rs
index b92bdc7..dba5638 100644
--- a/src/modules/chat/session.rs
+++ b/src/modules/chat/session.rs
@@ -1,13 +1,64 @@
-use std::collections::VecDeque;
+use std::collections::{HashMap, VecDeque};
use async_openai::types::{ChatCompletionRequestMessage as Message, Role};
use crate::config::SharedConfig;
+#[derive(Debug, Clone)]
+pub struct HistoryMessage {
+ pub id: i64,
+ pub message: Message,
+}
+
+#[derive(Debug, Default)]
+struct HistoryMessagePool {
+ current_id: i64,
+ messages: HashMap,
+ deque: VecDeque,
+}
+
+impl HistoryMessagePool {
+ fn prepare_message(&mut self, message: Message) -> HistoryMessage {
+ let (id, _) = self.current_id.overflowing_add(1);
+ self.current_id = id;
+
+ HistoryMessage { id, message }
+ }
+
+ fn push_message(&mut self, message: HistoryMessage) {
+ let id = message.id;
+ self.messages.insert(id, message);
+ self.deque.push_back(id);
+ }
+
+ fn pop_message(&mut self) {
+ if let Some(evicted_id) = self.deque.pop_front() {
+ self.messages.remove(&evicted_id);
+ }
+ }
+
+ fn clear(&mut self) {
+ self.deque.clear();
+ self.messages.clear();
+ }
+
+ fn len(&self) -> usize {
+ self.deque.len()
+ }
+
+ fn get_message(&self, id: &i64) -> Option<&HistoryMessage> {
+ self.messages.get(id)
+ }
+
+ fn iter(&self) -> impl Iterator- + '_ {
+ self.deque.iter().filter_map(|id| self.messages.get(id))
+ }
+}
+
#[derive(Debug)]
pub struct Session {
system_message: Option,
- messages: VecDeque,
+ history_messages: HistoryMessagePool,
pending_message: Option,
config: SharedConfig,
}
@@ -16,7 +67,7 @@ impl Session {
pub fn new(config: SharedConfig) -> Self {
Self {
system_message: None,
- messages: VecDeque::with_capacity(6),
+ history_messages: Default::default(),
pending_message: None,
config,
}
@@ -24,26 +75,36 @@ impl Session {
pub fn reset(&mut self) {
self.system_message = None;
- self.messages.clear();
+ self.history_messages.clear();
self.pending_message = None;
}
- pub fn add_message(&mut self, msg: Message) {
- if matches!(msg.role, Role::System) {
+ pub fn prepare_history_message(&mut self, message: Message) -> HistoryMessage {
+ self.history_messages.prepare_message(message)
+ }
+
+ pub fn add_history_message(&mut self, message: HistoryMessage) {
+ if matches!(message.message.role, Role::System) {
// Replace the previous system message, we only support
// one system message at the same time.
- self.system_message = Some(msg);
+ self.system_message = Some(message.message);
return;
}
- if self.messages.len() >= (self.config.conversation_limit as usize) {
- self.messages.pop_front();
+ if self.history_messages.len() >= (self.config.conversation_limit as usize) {
+ self.history_messages.pop_message();
}
- self.messages.push_back(msg);
+ self.history_messages.push_message(message);
+ }
+
+ pub fn get_history_message(&self, id: i64) -> Option {
+ self.history_messages
+ .get_message(&id)
+ .map(|m| m.message.clone())
}
pub fn get_history_messages(&self) -> Vec {
- let msg_iter = self.messages.iter().cloned();
+ let msg_iter = self.history_messages.iter().map(|m| m.message.clone());
if let Some(sys_msg) = &self.system_message {
let prepend = [sys_msg.to_owned()];
prepend.into_iter().chain(msg_iter).collect()
diff --git a/src/modules/chat/session_mgr.rs b/src/modules/chat/session_mgr.rs
index 8d4a349..4f086c9 100644
--- a/src/modules/chat/session_mgr.rs
+++ b/src/modules/chat/session_mgr.rs
@@ -31,10 +31,6 @@ impl SessionManager {
self.with_mut_session(key, |session| session.reset());
}
- pub fn add_message_to_session(&self, key: String, msg: Message) {
- self.with_mut_session(key, |session| session.add_message(msg));
- }
-
pub fn get_history_messages(&self, key: &str) -> Vec {
self.with_mut_inner(|inner| {
inner
@@ -53,15 +49,7 @@ impl SessionManager {
self.with_mut_session(key, |session| session.swap_pending_message(msg))
}
- fn with_mut_inner(&self, f: F) -> R
- where
- F: FnOnce(&mut SessionManagerInner) -> R,
- {
- let mut inner_mut = self.inner.lock().unwrap();
- f(&mut inner_mut)
- }
-
- fn with_mut_session(&self, key: String, f: F) -> R
+ pub fn with_mut_session(&self, key: String, f: F) -> R
where
F: FnOnce(&mut Session) -> R,
{
@@ -73,6 +61,14 @@ impl SessionManager {
f(session_mut)
})
}
+
+ fn with_mut_inner(&self, f: F) -> R
+ where
+ F: FnOnce(&mut SessionManagerInner) -> R,
+ {
+ let mut inner_mut = self.inner.lock().unwrap();
+ f(&mut inner_mut)
+ }
}
impl Clone for SessionManager {