Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: pass mid-connection WS events to task #269

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions examples/serenity/voice_receive/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,10 @@ async fn join(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
.expect("Songbird Voice client placed in at initialisation.")
.clone();

if let Ok(handler_lock) = manager.join(guild_id, connect_to).await {
// NOTE: this skips listening for the actual connection result.
// Some events relating to voice receive fire *while joining*.
// We must make sure that any event handlers are installed before we attempt to join.
{
let handler_lock = manager.get_or_insert(guild_id);
let mut handler = handler_lock.lock().await;

let evt_receiver = Receiver::new();
Expand All @@ -262,13 +264,18 @@ async fn join(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
handler.add_global_event(CoreEvent::RtcpPacket.into(), evt_receiver.clone());
handler.add_global_event(CoreEvent::ClientDisconnect.into(), evt_receiver.clone());
handler.add_global_event(CoreEvent::VoiceTick.into(), evt_receiver);
}

if let Ok(handler_lock) = manager.join(guild_id, connect_to).await {
check_msg(
msg.channel_id
.say(&ctx.http, &format!("Joined {}", connect_to.mention()))
.await,
);
} else {
// Although we failed to join, we need to clear out existing event handlers on the call.
_ = manager.remove(guild_id).await;

check_msg(
msg.channel_id
.say(&ctx.http, "Error joining the channel")
Expand Down
24 changes: 15 additions & 9 deletions src/driver/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ impl Connection {
let url = generate_url(&mut info.endpoint)?;

let mut client = WsStream::connect(url).await?;
let (ws_msg_tx, ws_msg_rx) = flume::unbounded();

let mut hello = None;
let mut ready = None;
Expand Down Expand Up @@ -93,7 +94,11 @@ impl Connection {
}
},
other => {
// Discord hold back per-user connection state until after this handshake.
// There's no guarantee that will remain the case, so buffer it like all
// subsequent steps where we know they *do* send these packets.
debug!("Expected ready/hello; got: {:?}", other);
ws_msg_tx.send(WsMessage::Deliver(other))?;
},
}
}
Expand Down Expand Up @@ -176,13 +181,12 @@ impl Connection {
.await?;
}

let cipher = init_cipher(&mut client, chosen_crypto).await?;
let cipher = init_cipher(&mut client, chosen_crypto, &ws_msg_tx).await?;

info!("Connected to: {}", info.endpoint);

info!("WS heartbeat duration {}ms.", hello.heartbeat_interval);

let (ws_msg_tx, ws_msg_rx) = flume::unbounded();
#[cfg(feature = "receive")]
let (udp_receiver_msg_tx, udp_receiver_msg_rx) = flume::unbounded();

Expand Down Expand Up @@ -304,7 +308,7 @@ impl Connection {
}
},
other => {
debug!("Expected resumed/hello; got: {:?}", other);
self.ws.send(WsMessage::Deliver(other))?;
},
}
}
Expand Down Expand Up @@ -338,7 +342,11 @@ fn generate_url(endpoint: &mut String) -> Result<Url> {
}

#[inline]
async fn init_cipher(client: &mut WsStream, mode: CryptoMode) -> Result<Cipher> {
async fn init_cipher(
client: &mut WsStream,
mode: CryptoMode,
tx: &Sender<WsMessage>,
) -> Result<Cipher> {
loop {
let Some(value) = client.recv_json().await? else {
continue;
Expand All @@ -355,11 +363,9 @@ async fn init_cipher(client: &mut WsStream, mode: CryptoMode) -> Result<Cipher>
.map_err(|_| Error::CryptoInvalidLength);
},
other => {
debug!(
"Expected ready for key; got: op{}/v{:?}",
other.kind() as u8,
other
);
// Discord can and will send user-specific payload packets during this time
// which are needed to map SSRCs to `UserId`s.
tx.send(WsMessage::Deliver(other))?;
},
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/driver/tasks/message/ws.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#![allow(missing_docs)]

use super::Interconnect;
use crate::ws::WsStream;
use crate::{model::Event as GatewayEvent, ws::WsStream};

pub enum WsMessage {
Ws(Box<WsStream>),
ReplaceInterconnect(Interconnect),
SetKeepalive(f64),
Speaking(bool),
Deliver(GatewayEvent),
}
3 changes: 3 additions & 0 deletions src/driver/tasks/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ impl AuxNetwork {
}
}
},
Ok(WsMessage::Deliver(msg)) => {
self.process_ws(interconnect, msg);
},
Err(flume::RecvError::Disconnected) => {
break;
},
Expand Down
Loading