Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: reuse websockets #203

Merged
merged 1 commit into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions presage/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,27 @@ impl<T: Clone> Default for CacheCell<T> {
}

impl<T: Clone> CacheCell<T> {
pub fn get<E>(&self, factory: impl FnOnce() -> Result<T, E>) -> Result<T, E> {
pub fn get(&self, factory: impl FnOnce() -> T) -> T {
let value = match self.cell.replace(None) {
Some(value) => value,
None => factory()?,
None => factory(),
};
self.cell.set(Some(value.clone()));
Ok(value)
value
}
}

#[cfg(test)]
mod tests {
use super::*;

use std::convert::Infallible;

#[test]
fn test_cache_cell() {
let cache: CacheCell<String> = Default::default();

let value = cache
.get(|| Ok::<_, Infallible>("Hello, World!".to_string()))
.unwrap();
assert_eq!(value, "Hello, World!");
let value = cache
.get(|| -> Result<String, Infallible> { panic!("I should not run") })
.unwrap();
let value = cache.get(|| ("Hello, World!".to_string()));
assert_eq!(value, "Hello, World!");

let value = cache
.get(|| -> Result<String, Infallible> { panic!("I should not run") })
.unwrap();
let value = cache.get(|| panic!("I should not run"));
assert_eq!(value, "Hello, World!");
}
}
43 changes: 18 additions & 25 deletions presage/src/manager/confirmation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use log::trace;
use rand::rngs::StdRng;
use rand::{RngCore, SeedableRng};

use crate::cache::CacheCell;
use crate::manager::registered::RegistrationData;
use crate::store::Store;
use crate::{Error, Manager};
Expand Down Expand Up @@ -122,31 +121,25 @@ impl<S: Store> Manager<S, Confirmation> {
let mut manager = Manager {
rng,
store: self.store,
state: Registered {
push_service_cache: CacheCell::default(),
identified_websocket: Default::default(),
unidentified_websocket: Default::default(),
unidentified_sender_certificate: Default::default(),
data: RegistrationData {
signal_servers: self.state.signal_servers,
device_name: None,
phone_number,
service_ids: ServiceIds {
aci: registered.uuid,
pni: registered.pni,
},
password,
signaling_key,
device_id: None,
registration_id,
pni_registration_id: Some(pni_registration_id),
aci_private_key: aci_identity_key_pair.private_key,
aci_public_key: aci_identity_key_pair.public_key,
pni_private_key: Some(pni_identity_key_pair.private_key),
pni_public_key: Some(pni_identity_key_pair.public_key),
profile_key,
state: Registered::with_data(RegistrationData {
signal_servers: self.state.signal_servers,
device_name: None,
phone_number,
service_ids: ServiceIds {
aci: registered.uuid,
pni: registered.pni,
},
},
password,
signaling_key,
device_id: None,
registration_id,
pni_registration_id: Some(pni_registration_id),
aci_private_key: aci_identity_key_pair.private_key,
aci_public_key: aci_identity_key_pair.public_key,
pni_private_key: Some(pni_identity_key_pair.private_key),
pni_public_key: Some(pni_identity_key_pair.public_key),
profile_key,
}),
};

manager.store.save_registration_data(&manager.state.data)?;
Expand Down
154 changes: 93 additions & 61 deletions presage/src/manager/registered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
use libsignal_service::configuration::{ServiceConfiguration, SignalServers, SignalingKey};
use libsignal_service::content::{Content, ContentBody, DataMessageFlags, Metadata};
use libsignal_service::groups_v2::{decrypt_group, Group, GroupsManager, InMemoryCredentialsCache};
use libsignal_service::messagepipe::{Incoming, ServiceCredentials};
use libsignal_service::messagepipe::{Incoming, MessagePipe, ServiceCredentials};
use libsignal_service::models::Contact;
use libsignal_service::prelude::phonenumber::PhoneNumber;
use libsignal_service::prelude::Uuid;
Expand Down Expand Up @@ -53,7 +53,8 @@
/// Manager state when the client is registered and can send and receive messages from Signal
#[derive(Clone)]
pub struct Registered {
pub(crate) push_service_cache: CacheCell<HyperPushService>,
pub(crate) identified_push_service: CacheCell<HyperPushService>,
pub(crate) unidentified_push_service: CacheCell<HyperPushService>,
pub(crate) identified_websocket: Arc<Mutex<Option<SignalWebSocket>>>,
pub(crate) unidentified_websocket: Arc<Mutex<Option<SignalWebSocket>>>,
pub(crate) unidentified_sender_certificate: Option<SenderCertificate>,
Expand All @@ -70,7 +71,8 @@
impl Registered {
pub(crate) fn with_data(data: RegistrationData) -> Self {
Self {
push_service_cache: CacheCell::default(),
identified_push_service: CacheCell::default(),
unidentified_push_service: CacheCell::default(),
identified_websocket: Default::default(),
unidentified_websocket: Default::default(),
unidentified_sender_certificate: Default::default(),
Expand Down Expand Up @@ -172,10 +174,76 @@
&self.state.data
}

/// Returns a clone of a cached push service (with credentials).
///
/// If no service is yet cached, it will create and cache one.
fn identified_push_service(&self) -> HyperPushService {
self.state.identified_push_service.get(|| {
HyperPushService::new(
self.state.service_configuration(),
self.credentials(),
crate::USER_AGENT.to_string(),
)
})
}

/// Returns a clone of a cached push service (without credentials).
///
/// If no service is yet cached, it will create and cache one.
fn unidentified_push_service(&self) -> HyperPushService {
self.state.unidentified_push_service.get(|| {
HyperPushService::new(
self.state.service_configuration(),
None,
crate::USER_AGENT.to_string(),
)
})
}

/// Returns the current identified websocket, or creates a new one
async fn identified_websocket(&self) -> Result<SignalWebSocket, Error<S::Error>> {
let mut identified_ws = self.state.identified_websocket.lock();

Check warning on line 205 in presage/src/manager/registered.rs

View workflow job for this annotation

GitHub Actions / clippy

this `MutexGuard` is held across an `await` point
match identified_ws.clone() {
Some(ws) => Ok(ws),
None => {
let keep_alive = true;
let headers = &[("X-Signal-Receive-Stories", "false")];
let ws = self
.identified_push_service()
.ws("/v1/websocket/", headers, self.credentials(), keep_alive)
.await?;
identified_ws.replace(ws.clone());
debug!("initialized identified websocket");

Ok(ws)
}
}
}

async fn unidentified_websocket(&self) -> Result<SignalWebSocket, Error<S::Error>> {
let mut unidentified_ws = self.state.unidentified_websocket.lock();

Check warning on line 224 in presage/src/manager/registered.rs

View workflow job for this annotation

GitHub Actions / clippy

this `MutexGuard` is held across an `await` point
match unidentified_ws.clone() {
Some(ws) => Ok(ws),
None => {
let keep_alive = true;
let ws = self
.unidentified_push_service()
.ws("/v1/websocket/", &[], None, keep_alive)
.await?;
unidentified_ws.replace(ws.clone());
debug!("initialized unidentified websocket");

Ok(ws)
}
}
}

pub(crate) async fn register_pre_keys(&mut self) -> Result<(), Error<S::Error>> {
trace!("registering pre keys");
let mut account_manager =
AccountManager::new(self.push_service()?, Some(self.state.data.profile_key));
let mut account_manager = AccountManager::new(
self.identified_push_service(),
Some(self.state.data.profile_key),
);

let (pre_keys_offset_id, next_signed_pre_key_id, next_pq_pre_key_id) = account_manager
.update_pre_key_bundle(
Expand All @@ -199,8 +267,10 @@

pub(crate) async fn set_account_attributes(&mut self) -> Result<(), Error<S::Error>> {
trace!("setting account attributes");
let mut account_manager =
AccountManager::new(self.push_service()?, Some(self.state.data.profile_key));
let mut account_manager = AccountManager::new(
self.identified_push_service(),
Some(self.state.data.profile_key),
);

let pni_registration_id =
if let Some(pni_registration_id) = self.state.data.pni_registration_id {
Expand Down Expand Up @@ -251,7 +321,7 @@
&mut self,
mut messages: impl Stream<Item = Content> + Unpin,
) -> Result<(), Error<S::Error>> {
let mut message_receiver = MessageReceiver::new(self.push_service()?);
let mut message_receiver = MessageReceiver::new(self.identified_push_service());
while let Some(Content { body, .. }) = messages.next().await {
if let ContentBody::SynchronizeMessage(SyncMessage {
contacts: Some(contacts),
Expand Down Expand Up @@ -333,7 +403,7 @@

if needs_renewal(self.state.unidentified_sender_certificate.as_ref()) {
let sender_certificate = self
.push_service()?
.identified_push_service()
.get_uuid_only_sender_certificate()
.await?;

Expand All @@ -354,7 +424,7 @@
token: &str,
captcha: &str,
) -> Result<(), Error<S::Error>> {
let mut account_manager = AccountManager::new(self.push_service()?, None);
let mut account_manager = AccountManager::new(self.identified_push_service(), None);
account_manager
.submit_recaptcha_challenge(token, captcha)
.await?;
Expand All @@ -363,7 +433,7 @@

/// Fetches basic information on the registered device.
pub async fn whoami(&self) -> Result<WhoAmIResponse, Error<S::Error>> {
Ok(self.push_service()?.whoami().await?)
Ok(self.identified_push_service().whoami().await?)
}

/// Fetches the profile (name, about, status emoji) of the registered user.
Expand All @@ -383,7 +453,8 @@
return Ok(profile);
}

let mut account_manager = AccountManager::new(self.push_service()?, Some(profile_key));
let mut account_manager =
AccountManager::new(self.identified_push_service(), Some(profile_key));

let profile = account_manager.retrieve_profile(uuid.into()).await?;

Expand All @@ -404,23 +475,8 @@
&mut self,
) -> Result<impl Stream<Item = Result<Incoming, ServiceError>>, Error<S::Error>> {
let credentials = self.credentials().ok_or(Error::NotYetRegisteredError)?;
let allow_stories = false;
let pipe = MessageReceiver::new(self.push_service()?)
.create_message_pipe(credentials, allow_stories)
.await?;

let service_configuration = self.state.service_configuration();
let mut unidentified_push_service =
HyperPushService::new(service_configuration, None, crate::USER_AGENT.to_string());
let unidentified_ws = unidentified_push_service
.ws("/v1/websocket/", &[], None, false)
.await?;
self.state.identified_websocket.lock().replace(pipe.ws());
self.state
.unidentified_websocket
.lock()
.replace(unidentified_ws);

let ws = self.identified_websocket().await?;
let pipe = MessagePipe::from_socket(ws, credentials);
Ok(pipe.stream())
}

Expand Down Expand Up @@ -449,7 +505,7 @@
let groups_credentials_cache = InMemoryCredentialsCache::default();
let groups_manager = GroupsManager::new(
self.state.data.service_ids.clone(),
self.push_service()?,
self.identified_push_service(),
groups_credentials_cache,
server_public_params,
);
Expand All @@ -472,13 +528,15 @@

let init = StreamState {
encrypted_messages: Box::pin(self.receive_messages_encrypted().await?),
message_receiver: MessageReceiver::new(self.push_service()?),
message_receiver: MessageReceiver::new(self.identified_push_service()),
service_cipher: self.new_service_cipher()?,
store: self.store.clone(),
groups_manager: self.groups_manager()?,
mode,
};

debug!("starting to consume incoming message stream");

Ok(futures::stream::unfold(init, |mut state| async move {
loop {
match state.encrypted_messages.next().await {
Expand Down Expand Up @@ -763,7 +821,7 @@
&self,
attachment_pointer: &AttachmentPointer,
) -> Result<Vec<u8>, Error<S::Error>> {
let mut service = self.push_service()?;
let mut service = self.identified_push_service();
let mut attachment_stream = service.get_attachment(attachment_pointer).await?;

// We need the whole file for the crypto to check out
Expand Down Expand Up @@ -804,45 +862,19 @@
})
}

/// Returns a clone of a cached push service.
///
/// If no service is yet cached, it will create and cache one.
fn push_service(&self) -> Result<HyperPushService, Error<S::Error>> {
self.state.push_service_cache.get(|| {
Ok(HyperPushService::new(
self.state.service_configuration(),
self.credentials(),
crate::USER_AGENT.to_string(),
))
})
}

/// Creates a new message sender.
async fn new_message_sender(&self) -> Result<MessageSender<S>, Error<S::Error>> {
let local_addr = ServiceAddress {
uuid: self.state.data.service_ids.aci,
};

let identified_websocket = self
.state
.identified_websocket
.lock()
.clone()
.ok_or(Error::MessagePipeNotStarted)?;

let mut unidentified_push_service = HyperPushService::new(
self.state.service_configuration(),
None,
crate::USER_AGENT.to_string(),
);
let unidentified_websocket = unidentified_push_service
.ws("/v1/websocket/", &[], None, false)
.await?;
let identified_websocket = self.identified_websocket().await?;
let unidentified_websocket = self.unidentified_websocket().await?;

Ok(MessageSender::new(
identified_websocket,
unidentified_websocket,
self.push_service()?,
self.identified_push_service(),
self.new_service_cipher()?,
self.rng.clone(),
self.store.clone(),
Expand Down
Loading