diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index a7dc405..d54fa2f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -18,5 +18,5 @@ jobs: - uses: actions/checkout@v3 - name: Build run: cargo build --verbose - #- name: Run tests - # run: cargo test --verbose + - name: Run tests + run: cargo test --verbose diff --git a/Cargo.lock b/Cargo.lock index 0287bfc..89d3477 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "aquamarine" version = "0.1.12" @@ -185,7 +200,12 @@ version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", "num-traits", + "wasm-bindgen", + "windows-targets 0.52.5", ] [[package]] @@ -773,6 +793,29 @@ dependencies = [ "tracing", ] +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "1.5.0" @@ -984,9 +1027,9 @@ checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] name = "memchr" -version = "2.7.2" +version = "2.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "6d0d8b92cd8358e8d229c11df9358decae64d137c5be540952c5ca7b25aea768" [[package]] name = "mime" @@ -1785,9 +1828,10 @@ checksum = "20f34339676cdcab560c9a82300c4c2581f68b9369aedf0fae86f2ff9565ff3e" [[package]] name = "telitairos-bot" -version = "0.1.2" +version = "0.2.0" dependencies = [ "async-openai", + "chrono", "log", "pretty_env_logger", "string-builder", @@ -2254,6 +2298,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.5", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/Cargo.toml b/Cargo.toml index ac53f45..2c57362 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "telitairos-bot" description = "A fully funcional AI Powered assistant Telegram Bot" -version = "0.1.2" +version = "0.2.0" edition = "2021" license = "MIT" keywords = ["ai", "telegram", "bot", "admin", "chatbot"] @@ -10,9 +10,9 @@ exclude = ["src/main.rs"] [dependencies] async-openai = "0.23.2" - teloxide = { version = "0.12", features = ["macros"] } log = "0.4" pretty_env_logger = "0.4" tokio = { version = "1.38", features = ["rt-multi-thread", "macros"] } string-builder = "0.2.0" +chrono = "0.4.38" diff --git a/readme.md b/readme.md index 5c96670..bab2329 100644 --- a/readme.md +++ b/readme.md @@ -38,7 +38,12 @@ tokio = { version = "1.8", features = ["rt-multi-thread", "macros"] } ``` ## Supported commands -You can either: +You can do either: + +### 👮🚨 ADMIN Commands +- `/mute X {h/m/s/p}` -> Mute an User from the Chat Group the selected time. 'p' is for 'permanent' +- `/ban X {h/m/s/p}` -> Ban an User from the Chat Group the selected time. 'p' is for 'permanent' +### 🦀 AI Commands - `/ask` for a specified question. - `/mediate` to read the last N messages of a chat group and mitigate an argument. diff --git a/src/bot.rs b/src/bot.rs deleted file mode 100644 index 5fd900e..0000000 --- a/src/bot.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::collections::VecDeque; - -use crate::{gpt, types, TelitairoBot}; -use teloxide::{prelude::*, utils::command::BotCommands}; - -#[derive(BotCommands, Clone)] -#[command( - rename_rule = "lowercase", - description = "These commands are supported" -)] -pub enum Command { - #[command(description = "Display this text.")] - Help, - - #[command(description = "Ask the bot a question.")] - Ask(String), - - #[command(description = "Ask the bot to mediate a discussion")] - Mediate, -} - -pub async fn handle_commands( - bot: Bot, - buffer_store: types::BufferStore, - telitairo_bot: TelitairoBot, - msg: Message, - cmd: Command, -) -> ResponseResult<()> { - match cmd { - Command::Help => { - bot.send_message(msg.chat.id, Command::descriptions().to_string()) - .await?; - } - Command::Ask(question) => { - let answer = match gpt::ask(question, telitairo_bot).await { - Ok(response) => response, - Err(err) => format!("Error getting an answer from OpenAI: {err}"), - }; - - bot.send_message(msg.chat.id, answer).await?; - } - Command::Mediate => { - let answer = match gpt::mediate(buffer_store, telitairo_bot, msg.chat.id).await { - Ok(response) => response, - Err(err) => format!("Error getting an answer from OpenAI: {err}"), - }; - - bot.send_message(msg.chat.id, answer).await?; - } - }; - - Ok(()) -} - -pub async fn handle_messages( - buffer_store: types::BufferStore, - telitairo_bot: TelitairoBot, - msg: Message, -) -> ResponseResult<()> { - let mut buffer_store_lock = buffer_store.write().await; - match buffer_store_lock.get_mut(&msg.chat.id) { - Some(buffer) => { - if buffer.len() == telitairo_bot.buffer_size { - buffer.pop_front(); - } - buffer.push_back(msg.clone()); - } - None => { - let mut buffer = VecDeque::new(); - buffer.push_back(msg.clone()); - buffer_store_lock.insert(msg.chat.id, buffer); - } - } - - Ok(()) -} diff --git a/src/bot/admin.rs b/src/bot/admin.rs new file mode 100644 index 0000000..24b9890 --- /dev/null +++ b/src/bot/admin.rs @@ -0,0 +1,152 @@ +use crate::bot::*; +use chrono::Duration; +use teloxide::{ + payloads::RestrictChatMemberSetters, + types::{ChatPermissions, ParseMode}, +}; + +#[derive(BotCommands, Clone, PartialEq)] +#[command( + rename_rule = "lowercase", + description = "Supported ADMIN Commands", + parse_with = "split" +)] +pub enum AdminCommand { + #[command(description = "Display this text\\.")] + Help, + + #[command( + description = "`/mute X {h/m/s/p}` \\-\\> Mute an User from the Chat Group the selected time\\. 'p' is for 'permanent'" + )] + Mute(types::TimeAmount, types::UnitOfTime), + + #[command( + description = "`/ban X {h/m/s/p}` \\-\\> Ban an User from the Chat Group the selected time\\. 'p' is for 'permanent'" + )] + Ban(types::TimeAmount, types::UnitOfTime), +} + +pub async fn handle_admin_commands( + bot: Bot, + msg: Message, + cmd: AdminCommand, +) -> ResponseResult<()> { + match cmd { + AdminCommand::Help => { + bot.send_message(msg.chat.id, all_command_descriptions()) + .parse_mode(ParseMode::MarkdownV2) + .await?; + } + AdminCommand::Mute(time_amount, unit_of_time) => { + mute_user(bot, msg, calc_time(time_amount, unit_of_time)).await?; + } + AdminCommand::Ban(time_amount, unit_of_time) => { + ban_user(bot, msg, calc_time(time_amount, unit_of_time)).await?; + } + }; + + Ok(()) +} + +async fn mute_user(bot: Bot, msg: Message, time: Option) -> ResponseResult<()> { + let duration = match time { + Some(d) => d, + None => { + bot.send_message(msg.chat.id, "Send a properly formatted time span") + .await?; + return Ok(()); + } + }; + + match msg.reply_to_message() { + Some(replied) => { + bot.restrict_chat_member( + msg.chat.id, + replied.from().expect("Must be MessageKind::Common").id, + ChatPermissions::empty(), + ) + .until_date(msg.date + duration) + .await?; + } + None => { + bot.send_message( + msg.chat.id, + "Use this command in a reply to another message!", + ) + .await?; + } + } + + Ok(()) +} + +async fn ban_user(bot: Bot, msg: Message, time: Option) -> ResponseResult<()> { + let duration = match time { + Some(d) => d, + None => { + bot.send_message(msg.chat.id, "Send a properly formatted time span") + .await?; + return Ok(()); + } + }; + + match msg.reply_to_message() { + Some(replied) => { + bot.kick_chat_member( + msg.chat.id, + replied.from().expect("Must be MessageKind::Common").id, + ) + .until_date(msg.date + duration) + .await?; + } + None => { + bot.send_message( + msg.chat.id, + "Use this command in a reply to another message!", + ) + .await?; + } + } + + Ok(()) +} + +fn calc_time(time_amount: types::TimeAmount, unit_of_time: types::UnitOfTime) -> Option { + match unit_of_time { + types::UnitOfTime::Seconds => Duration::try_seconds(time_amount.into()), + types::UnitOfTime::Minutes => Duration::try_minutes(time_amount.into()), + types::UnitOfTime::Hours => Duration::try_hours(time_amount.into()), + types::UnitOfTime::Permanent => Some(Duration::max_value()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::*; + use chrono::Duration; + + #[test] + fn calc_time_seconds() { + let result = calc_time(10, UnitOfTime::Seconds); + assert_eq!(result, Some(Duration::seconds(10))); + } + + #[test] + fn calc_time_minutes() { + let result = calc_time(10, UnitOfTime::Minutes); + assert_eq!(result, Some(Duration::seconds(600))); + } + + #[test] + fn calc_time_hours() { + let result = calc_time(2, UnitOfTime::Hours); + assert_eq!(result, Some(Duration::seconds(7200))); + } + + #[test] + fn calc_time_permanent() { + let result = calc_time(0, UnitOfTime::Permanent); + assert_eq!(result, Some(Duration::max_value())); + } +} diff --git a/src/bot/ai.rs b/src/bot/ai.rs new file mode 100644 index 0000000..2aed7a8 --- /dev/null +++ b/src/bot/ai.rs @@ -0,0 +1,39 @@ +use crate::*; + +#[derive(BotCommands, Clone)] +#[command(rename_rule = "lowercase", description = "Supported AI Commands")] +pub enum AiCommand { + #[command(description = "Ask the bot a question")] + Ask(String), + + #[command(description = "Ask the bot to mediate a discussion")] + Mediate, +} + +pub async fn handle_ai_commands( + bot: Bot, + buffer_store: types::BufferStore, + telitairo_bot: TelitairoBot, + msg: Message, + cmd: AiCommand, +) -> ResponseResult<()> { + match cmd { + AiCommand::Ask(question) => { + let answer = match gpt::ask(question, telitairo_bot).await { + Ok(response) => response, + Err(err) => format!("Error getting an answer from OpenAI: {err}"), + }; + bot.send_message(msg.chat.id, answer).await?; + } + AiCommand::Mediate => { + let answer = match gpt::mediate(buffer_store, telitairo_bot, msg.chat.id).await { + Ok(response) => response, + Err(err) => format!("Error getting an answer from OpenAI: {err}"), + }; + + bot.send_message(msg.chat.id, answer).await?; + } + }; + + Ok(()) +} diff --git a/src/bot/mod.rs b/src/bot/mod.rs new file mode 100644 index 0000000..31f0d31 --- /dev/null +++ b/src/bot/mod.rs @@ -0,0 +1,58 @@ +pub mod admin; +pub mod ai; + +use crate::{types, TelitairoBot}; +use std::collections::VecDeque; +pub use teloxide::{prelude::*, utils::command::BotCommands}; + +pub async fn handle_messages( + buffer_store: types::BufferStore, + telitairo_bot: TelitairoBot, + msg: Message, +) -> ResponseResult<()> { + let mut buffer_store_lock = buffer_store.write().await; + match buffer_store_lock.get_mut(&msg.chat.id) { + Some(buffer) => { + if buffer.len() == telitairo_bot.buffer_size { + buffer.pop_front(); + } + buffer.push_back(msg.clone()); + } + None => { + let mut buffer = VecDeque::new(); + buffer.push_back(msg.clone()); + buffer_store_lock.insert(msg.chat.id, buffer); + } + } + + Ok(()) +} + +pub fn all_command_descriptions() -> String { + let admin_command_descriptions = admin::AdminCommand::descriptions(); + let ai_command_descriptions = ai::AiCommand::descriptions(); + + format!("👮🚨{admin_command_descriptions}\n\n\n🦀 🤖{ai_command_descriptions}") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_command_descriptions(){ + assert_eq!(all_command_descriptions(), +"👮🚨Supported ADMIN Commands + +/help — Display this text\\. +/mute — `/mute X {h/m/s/p}` \\-\\> Mute an User from the Chat Group the selected time\\. 'p' is for 'permanent' +/ban — `/ban X {h/m/s/p}` \\-\\> Ban an User from the Chat Group the selected time\\. 'p' is for 'permanent' + + +🦀 🤖Supported AI Commands + +/ask — Ask the bot a question +/mediate — Ask the bot to mediate a discussion" +) + } +} diff --git a/src/lib.rs b/src/lib.rs index 4c938d2..a15320f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,8 @@ //! - `/mediate` to read the last N messages of a chat group and mitigate an argument. //! //! ## Environment variables needed -//! ``` +//! +//! ```bash //! - TELOXIDE_TOKEN= "/* Your Telegram Bot API Key */" //! - OPENAI_API_KEY= "/* Your OpenAI API Key */" //! - OPENAI_ORG_ID= "/* Your OpenAI Organization ID */" @@ -25,50 +26,51 @@ //! by just doing this: //! //! ``` +//! # use telitairos_bot::TelitairoBot; //! let telitairo_bot: TelitairoBot = Default::default(); //! ``` //! //! But if you want to set your own Bot's personality you can use the `new()` function like this: //! //! ``` -//! #[tokio::main] -//! async fn main() { -//! pretty_env_logger::init(); -//! log::info!("Starting bot"); +//! # use telitairos_bot::TelitairoBot; +//! /* #[tokio::main] */ +//! /*async*/ fn main() { +//! let telitairo_bot = TelitairoBot::new( +//! String::from("personality"), +//! String::from("mediation criteria"), +//! 200 /* buffer size */, +//! ); //! -//! let telitairo_bot = TelitairoBot::new( -//! String::from(/*Personality */), -//! String::from(/* Mediation criteria */), -//! /*size */, -//! ); -//! -//! telitairo_bot.dispatch().await; -//! } +//! /* telitairo_bot.dispatch().await; */ +//! # } +//! ``` //! mod bot; mod gpt; mod types; +use crate::bot::*; use std::collections::HashMap; use std::sync::Arc; -use teloxide::prelude::*; +use teloxide::dispatching::{HandlerExt, UpdateFilterExt}; use teloxide::{dptree, Bot}; use tokio::sync::RwLock; /// Defines the bot behavior -#[derive(Clone)] +#[derive(Clone, Debug, PartialEq)] pub struct TelitairoBot { /// String to define the bot personality, a descriptive short prompt. /// /// # Example - /// ``` + /// ```bash /// "You are a virtual assistant with a touch of acid humour and you love potatoes" /// ``` pub personality: String, /// String to define the bot action when `/mediate` command is sent. descriptive short prompt. /// /// # Example - /// ``` + /// ```bash /// "Take the messages, search for possible discussions and choose one side" /// ``` pub mediate_query: String, @@ -97,13 +99,18 @@ impl TelitairoBot { let bot = Bot::from_env(); let buffer_store: types::BufferStore = Arc::new(RwLock::new(HashMap::new())); - let handler = dptree::entry() + let handler = Update::filter_message() .branch( - Update::filter_message() - .filter_command::() - .endpoint(bot::handle_commands), + dptree::entry() + .filter_command::() + .endpoint(ai::handle_ai_commands), ) - .branch(Update::filter_message().endpoint(bot::handle_messages)); + .branch( + dptree::entry() + .filter_command::() + .endpoint(admin::handle_admin_commands), + ) + .branch(dptree::endpoint(handle_messages)); Dispatcher::builder(bot, handler) .dependencies(dptree::deps![buffer_store, self.clone()]) @@ -130,3 +137,27 @@ impl Default for TelitairoBot { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_and_new() { + let telitairo_default = TelitairoBot::default(); + let telitairo_partially_default = TelitairoBot { + buffer_size: 200, + ..Default::default() + }; + let telitairo_new = TelitairoBot::new( + String::from( + "You are a virtual assistant with a touch of acid humour and you love potatoes", + ), + String::from("Take the messages, search for possible discussions and choose one side"), + 200, + ); + + assert_eq!(telitairo_default, telitairo_new); + assert_eq!(telitairo_partially_default, telitairo_new); + } +} diff --git a/src/types.rs b/src/types.rs index 473860d..5fa52d0 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,5 +1,6 @@ use std::{ collections::{HashMap, VecDeque}, + str::FromStr, sync::Arc, }; use teloxide::types::{ChatId, Message}; @@ -12,3 +13,44 @@ pub const DEFAULT_MEDIATION_QUERY: &str = pub const DEFAULT_BUFFER_SIZE: usize = 200; pub type BufferStore = Arc>>>; + +pub type TimeAmount = u8; + +#[derive(Clone, Debug, PartialEq)] +pub enum UnitOfTime { + Seconds, + Minutes, + Hours, + Permanent, +} + +impl FromStr for UnitOfTime { + type Err = &'static str; + fn from_str(s: &str) -> Result { + match s { + "h" => Ok(UnitOfTime::Hours), + "m" => Ok(UnitOfTime::Minutes), + "s" => Ok(UnitOfTime::Seconds), + "p" => Ok(UnitOfTime::Permanent), + _ => Err("Allowed units: h, m, s, p"), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn conversion_ok() { + let uot: UnitOfTime = UnitOfTime::from_str("h").expect("Failed to convert"); + assert_eq!(uot, UnitOfTime::Hours); + } + + #[test] + fn conversion_nok() { + let result = UnitOfTime::from_str("x"); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Allowed units: h, m, s, p"); + } +}