diff --git a/examples/advanced_cooldowns/main.rs b/examples/advanced_cooldowns/main.rs index 417f9a79cc1d..0b3fe3c09ac1 100644 --- a/examples/advanced_cooldowns/main.rs +++ b/examples/advanced_cooldowns/main.rs @@ -15,11 +15,11 @@ async fn dynamic_cooldowns(ctx: Context<'_>) -> Result<(), Error> { cooldown_durations.user = Some(std::time::Duration::from_secs(10)); } - match cooldown_tracker.remaining_cooldown(ctx, &cooldown_durations) { + match cooldown_tracker.remaining_cooldown(ctx.cooldown_context(), &cooldown_durations) { Some(remaining) => { return Err(format!("Please wait {} seconds", remaining.as_secs()).into()) } - None => cooldown_tracker.start_cooldown(ctx), + None => cooldown_tracker.start_cooldown(ctx.cooldown_context()), } }; diff --git a/macros/src/command/prefix.rs b/macros/src/command/prefix.rs index 1ace757f4609..681e86ebaa3a 100644 --- a/macros/src/command/prefix.rs +++ b/macros/src/command/prefix.rs @@ -64,7 +64,7 @@ pub fn generate_prefix_action(inv: &Invocation) -> Result Result>::to_action(|ctx, value| { Box::pin(async move { if !ctx.framework.options.manual_cooldowns { - ctx.command.cooldowns.lock().unwrap().start_cooldown(ctx.into()); + ctx.command.cooldowns.lock().unwrap().start_cooldown(ctx.cooldown_context()); } inner(ctx.into(), value) diff --git a/src/cooldown.rs b/src/cooldown.rs index a977a4b41161..0795744c50b1 100644 --- a/src/cooldown.rs +++ b/src/cooldown.rs @@ -5,6 +5,18 @@ use crate::serenity_prelude as serenity; use crate::util::OrderedMap; use std::time::{Duration, Instant}; +/// Subset of [`crate::Context`] so that [`Cooldowns`] can be used without requiring a full [Context](`crate::Context`) +/// (ie from within an `event_handler`) +#[derive(Default, Clone, PartialEq, Eq, Debug, Hash)] +pub struct CooldownContext { + /// The user associated with this request + pub user_id: serenity::UserId, + /// The guild this request originated from or `None` + pub guild_id: Option, + /// The channel associated with this request + pub channel_id: serenity::ChannelId, +} + /// Configuration struct for [`Cooldowns`] #[derive(Default, Clone, PartialEq, Eq, Debug, Hash)] pub struct CooldownConfig { @@ -57,24 +69,24 @@ impl CooldownTracker { /// Queries the cooldown buckets and checks if all cooldowns have expired and command /// execution may proceed. If not, Some is returned with the remaining cooldown - pub fn remaining_cooldown( + pub fn remaining_cooldown( &self, - ctx: crate::Context<'_, U, E>, + ctx: CooldownContext, cooldown_durations: &CooldownConfig, ) -> Option { let mut cooldown_data = vec![ (cooldown_durations.global, self.global_invocation), ( cooldown_durations.user, - self.user_invocations.get(&ctx.author().id).copied(), + self.user_invocations.get(&ctx.user_id).copied(), ), ( cooldown_durations.channel, - self.channel_invocations.get(&ctx.channel_id()).copied(), + self.channel_invocations.get(&ctx.channel_id).copied(), ), ]; - if let Some(guild_id) = ctx.guild_id() { + if let Some(guild_id) = ctx.guild_id { cooldown_data.push(( cooldown_durations.guild, self.guild_invocations.get(&guild_id).copied(), @@ -82,7 +94,7 @@ impl CooldownTracker { cooldown_data.push(( cooldown_durations.member, self.member_invocations - .get(&(ctx.author().id, guild_id)) + .get(&(ctx.user_id, guild_id)) .copied(), )); } @@ -98,17 +110,26 @@ impl CooldownTracker { } /// Indicates that a command has been executed and all associated cooldowns should start running - pub fn start_cooldown(&mut self, ctx: crate::Context<'_, U, E>) { + pub fn start_cooldown(&mut self, ctx: CooldownContext) { let now = Instant::now(); self.global_invocation = Some(now); - self.user_invocations.insert(ctx.author().id, now); - self.channel_invocations.insert(ctx.channel_id(), now); + self.user_invocations.insert(ctx.user_id, now); + self.channel_invocations.insert(ctx.channel_id, now); - if let Some(guild_id) = ctx.guild_id() { + if let Some(guild_id) = ctx.guild_id { self.guild_invocations.insert(guild_id, now); - self.member_invocations - .insert((ctx.author().id, guild_id), now); + self.member_invocations.insert((ctx.user_id, guild_id), now); + } + } +} + +impl<'a> From<&'a serenity::Message> for CooldownContext { + fn from(message: &'a serenity::Message) -> Self { + Self { + user_id: message.author.id, + channel_id: message.channel_id, + guild_id: message.guild_id, } } } diff --git a/src/dispatch/common.rs b/src/dispatch/common.rs index 7b4985732ee8..14cc2fa9708f 100644 --- a/src/dispatch/common.rs +++ b/src/dispatch/common.rs @@ -158,9 +158,9 @@ async fn check_permissions_and_cooldown_single<'a, U, E>( } if !ctx.framework().options().manual_cooldowns { - let cooldowns = &cmd.cooldowns; + let cooldowns = cmd.cooldowns.lock().unwrap(); let config = cmd.cooldown_config.read().unwrap(); - let remaining_cooldown = cooldowns.lock().unwrap().remaining_cooldown(ctx, &config); + let remaining_cooldown = cooldowns.remaining_cooldown(ctx.cooldown_context(), &config); if let Some(remaining_cooldown) = remaining_cooldown { return Err(crate::FrameworkError::CooldownHit { ctx, diff --git a/src/structs/context.rs b/src/structs/context.rs index f886f92ea595..759df912568e 100644 --- a/src/structs/context.rs +++ b/src/structs/context.rs @@ -171,6 +171,16 @@ context_methods! { } } + /// Create a [`crate::CooldownContext`] based off the underlying context type. + (cooldown_context self) + (pub fn cooldown_context(self) -> crate::CooldownContext) { + crate::CooldownContext { + user_id: self.author().id, + channel_id: self.channel_id(), + guild_id: self.guild_id() + } + } + /// See [`Self::serenity_context`]. #[deprecated = "poise::Context can now be passed directly into most serenity functions. Otherwise, use `.serenity_context()` now"] #[allow(deprecated)]