From 41ae54063f69c82fe61aca007fe3f7ea568a34a2 Mon Sep 17 00:00:00 2001 From: Kyle Simpson Date: Tue, 26 Nov 2024 00:08:23 +0000 Subject: [PATCH] Fix: pass mid-connection WS events to task Discord will send `GatewayEvent::Speaking` (opcode 5) messages after the Hello+Ready exchange, but will happily interleave them with crypto mode negotiation. We were previously not expecting such messages and dropping them -- this hurts receive-based bots' ability to map SSRCs to UserIds when joining a call with existing users. This PR feeds all unexpected messages into the WS task directly, which will handle them once all tasks are fully started. --- examples/serenity/voice_receive/src/main.rs | 11 ++++++++-- src/driver/connection/mod.rs | 24 +++++++++++++-------- src/driver/tasks/message/ws.rs | 3 ++- src/driver/tasks/ws.rs | 3 +++ 4 files changed, 29 insertions(+), 12 deletions(-) diff --git a/examples/serenity/voice_receive/src/main.rs b/examples/serenity/voice_receive/src/main.rs index f3fa04eeb..4d20d4fad 100644 --- a/examples/serenity/voice_receive/src/main.rs +++ b/examples/serenity/voice_receive/src/main.rs @@ -251,8 +251,10 @@ async fn join(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { .expect("Songbird Voice client placed in at initialisation.") .clone(); - if let Ok(handler_lock) = manager.join(guild_id, connect_to).await { - // NOTE: this skips listening for the actual connection result. + // Some events relating to voice receive fire *while joining*. + // We must make sure that any event handlers are installed before we attempt to join. + { + let handler_lock = manager.get_or_insert(guild_id); let mut handler = handler_lock.lock().await; let evt_receiver = Receiver::new(); @@ -262,13 +264,18 @@ async fn join(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { handler.add_global_event(CoreEvent::RtcpPacket.into(), evt_receiver.clone()); handler.add_global_event(CoreEvent::ClientDisconnect.into(), evt_receiver.clone()); handler.add_global_event(CoreEvent::VoiceTick.into(), evt_receiver); + } + if let Ok(handler_lock) = manager.join(guild_id, connect_to).await { check_msg( msg.channel_id .say(&ctx.http, &format!("Joined {}", connect_to.mention())) .await, ); } else { + // Although we failed to join, we need to clear out existing event handlers on the call. + _ = manager.remove(guild_id).await; + check_msg( msg.channel_id .say(&ctx.http, "Error joining the channel") diff --git a/src/driver/connection/mod.rs b/src/driver/connection/mod.rs index 41883c058..de2fca1eb 100644 --- a/src/driver/connection/mod.rs +++ b/src/driver/connection/mod.rs @@ -61,6 +61,7 @@ impl Connection { let url = generate_url(&mut info.endpoint)?; let mut client = WsStream::connect(url).await?; + let (ws_msg_tx, ws_msg_rx) = flume::unbounded(); let mut hello = None; let mut ready = None; @@ -93,7 +94,11 @@ impl Connection { } }, other => { + // Discord hold back per-user connection state until after this handshake. + // There's no guarantee that will remain the case, so buffer it like all + // subsequent steps where we know they *do* send these packets. debug!("Expected ready/hello; got: {:?}", other); + ws_msg_tx.send(WsMessage::Deliver(other))?; }, } } @@ -176,13 +181,12 @@ impl Connection { .await?; } - let cipher = init_cipher(&mut client, chosen_crypto).await?; + let cipher = init_cipher(&mut client, chosen_crypto, &ws_msg_tx).await?; info!("Connected to: {}", info.endpoint); info!("WS heartbeat duration {}ms.", hello.heartbeat_interval); - let (ws_msg_tx, ws_msg_rx) = flume::unbounded(); #[cfg(feature = "receive")] let (udp_receiver_msg_tx, udp_receiver_msg_rx) = flume::unbounded(); @@ -304,7 +308,7 @@ impl Connection { } }, other => { - debug!("Expected resumed/hello; got: {:?}", other); + self.ws.send(WsMessage::Deliver(other))?; }, } } @@ -338,7 +342,11 @@ fn generate_url(endpoint: &mut String) -> Result { } #[inline] -async fn init_cipher(client: &mut WsStream, mode: CryptoMode) -> Result { +async fn init_cipher( + client: &mut WsStream, + mode: CryptoMode, + tx: &Sender, +) -> Result { loop { let Some(value) = client.recv_json().await? else { continue; @@ -355,11 +363,9 @@ async fn init_cipher(client: &mut WsStream, mode: CryptoMode) -> Result .map_err(|_| Error::CryptoInvalidLength); }, other => { - debug!( - "Expected ready for key; got: op{}/v{:?}", - other.kind() as u8, - other - ); + // Discord can and will send user-specific payload packets during this time + // which are needed to map SSRCs to `UserId`s. + tx.send(WsMessage::Deliver(other))?; }, } } diff --git a/src/driver/tasks/message/ws.rs b/src/driver/tasks/message/ws.rs index 4faf68367..77a9d7a9d 100644 --- a/src/driver/tasks/message/ws.rs +++ b/src/driver/tasks/message/ws.rs @@ -1,11 +1,12 @@ #![allow(missing_docs)] use super::Interconnect; -use crate::ws::WsStream; +use crate::{model::Event as GatewayEvent, ws::WsStream}; pub enum WsMessage { Ws(Box), ReplaceInterconnect(Interconnect), SetKeepalive(f64), Speaking(bool), + Deliver(GatewayEvent), } diff --git a/src/driver/tasks/ws.rs b/src/driver/tasks/ws.rs index 167a72d03..449cd9791 100644 --- a/src/driver/tasks/ws.rs +++ b/src/driver/tasks/ws.rs @@ -145,6 +145,9 @@ impl AuxNetwork { } } }, + Ok(WsMessage::Deliver(msg)) => { + self.process_ws(interconnect, msg); + }, Err(flume::RecvError::Disconnected) => { break; },